Skip to content

Commit 7b38baa

Browse files
authored
feat: add protoc-gen-connect-go-service-struct plugin
2 parents c73622a + 30b6c46 commit 7b38baa

File tree

6 files changed

+709
-7
lines changed

6 files changed

+709
-7
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ go.work.sum
3030
# Editor/IDE
3131
# .idea/
3232
# .vscode/
33+
34+
# Generated build files
35+
protoc-gen-connect-go-servicestruct
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# protoc-gen-connect-go-servicestruct
2+
3+
A protobuf compiler plugin that generates Go service structs with handler
4+
functions for Connect RPC services. **Requires `protoc-gen-connect-go` as a dependency.**
5+
6+
## How-tos
7+
8+
### Installation
9+
10+
```bash
11+
go install github.com/TrogonStack/protoc-gen/cmd/protoc-gen-connect-go-servicestruct@latest
12+
```
13+
14+
### Basic Usage
15+
16+
**Required**: This plugin must be used alongside `protoc-gen-connect-go` as it generates structs that implement the
17+
Connect handler interfaces:
18+
19+
```bash
20+
protoc --go_out=gen --connect-go_out=gen --connect-go-servicestruct_out=gen path/to/service.proto
21+
```
22+
23+
### With buf
24+
25+
**Required**: Add to your `buf.gen.yaml` alongside the required `buf.build/connectrpc/go` plugin:
26+
27+
```yaml
28+
version: v2
29+
plugins:
30+
- remote: buf.build/protocolbuffers/go
31+
out: gen
32+
opt: paths=source_relative
33+
- remote: buf.build/connectrpc/go:v1.18.1 # Required dependency
34+
out: gen
35+
opt: paths=source_relative
36+
- local: protoc-gen-connect-go-servicestruct
37+
out: gen
38+
opt:
39+
- paths=source_relative
40+
```
41+
42+
Files are generated into `{packagename}connect` subdirectories alongside the connect plugin output.
43+
44+
### Basic Service Implementation
45+
46+
For more realistic applications that require dependencies like database connections, use constructor functions to
47+
create handlers:
48+
49+
```go
50+
package main
51+
52+
import (
53+
"context"
54+
"fmt"
55+
"log"
56+
"net/http"
57+
58+
"connectrpc.com/connect"
59+
"github.com/jackc/pgx/v5/pgxpool"
60+
greeterv1 "example.com/greeter/gen/greeter/v1"
61+
"example.com/greeter/gen/greeter/v1/greeterv1connect"
62+
)
63+
64+
func main() {
65+
// Initialize database connection
66+
dbPool, err := pgxpool.New(context.Background(), "postgres://user:pass@localhost/db")
67+
if err != nil {
68+
log.Fatal("Failed to create connection pool:", err)
69+
}
70+
defer dbPool.Close()
71+
72+
mux := http.NewServeMux()
73+
path, handler := greeterv1connect.NewGreeterServiceHandler(&greeterv1connect.GreeterServiceStruct{
74+
GreetFunc: newGreetHandler(dbPool),
75+
GreetStreamFunc: newGreetStreamHandler(dbPool),
76+
})
77+
mux.Handle(path, handler)
78+
79+
log.Println("Greeter server starting on :8080")
80+
http.ListenAndServe(":8080", mux)
81+
}
82+
83+
func newGreetHandler(db *pgxpool.Pool) greeterv1.GreeterServiceGreetHandlerFunc {
84+
return func(ctx context.Context, req *connect.Request[greeterv1.GreetRequest]) (*connect.Response[greeterv1.GreetResponse], error) {
85+
// ...
86+
}
87+
}
88+
89+
func newGreetStreamHandler(db *pgxpool.Pool) greeterv1.GreeterServiceGreetStreamHandlerFunc {
90+
return func(ctx context.Context, stream *connect.BidiStream[greeterv1.GreetRequest, greeterv1.GreetResponse]) error {
91+
// ...
92+
}
93+
}
94+
```
95+
96+
## Explanations
97+
98+
### Overview
99+
100+
This plugin generates service structs that provide a convenient way to organize and access handler functions for
101+
Connect RPC services. **This plugin requires `protoc-gen-connect-go` (or `buf.build/connectrpc/go`) to function** - it
102+
generates structs that implement the Connect handler interfaces created by the standard Connect plugin.
103+
104+
### Features
105+
106+
- **Service Structs**: Generates public structs with handler function fields.
107+
- **Type Safety**: Uses strongly-typed handler function types.
108+
- **Interface Compatibility**: Generated structs implement the standard handler interfaces for backwards compatibility.
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
// protoc-gen-connect-go-servicestruct is a plugin for the Protobuf compiler that generates
2+
// Go service structs with handler functions. To use it, build this program and make
3+
// it available on your PATH as protoc-gen-connect-go-servicestruct.
4+
//
5+
// The 'connect-go-servicestruct' suffix becomes part of the arguments for the Protobuf
6+
// compiler. To generate service structs using protoc:
7+
//
8+
// protoc --go_out=gen --connect-go-servicestruct_out=gen path/to/file.proto
9+
//
10+
// With [buf], your buf.gen.yaml will look like this:
11+
//
12+
// version: v2
13+
// plugins:
14+
// - local: protoc-gen-go
15+
// out: gen
16+
// - local: protoc-gen-connect-go-servicestruct
17+
// out: gen
18+
//
19+
// This generates service struct definitions for the Protobuf services
20+
// defined by file.proto. If file.proto defines the foov1 Protobuf package, the
21+
// invocations above will write output to:
22+
//
23+
// gen/path/to/file.pb.go
24+
// gen/path/to/foov1connect/file.servicestruct.connect.go
25+
//
26+
// [buf]: https://buf.build
27+
package main
28+
29+
import (
30+
"fmt"
31+
"os"
32+
"path"
33+
"strings"
34+
35+
connect "connectrpc.com/connect"
36+
"google.golang.org/protobuf/compiler/protogen"
37+
"google.golang.org/protobuf/types/descriptorpb"
38+
"google.golang.org/protobuf/types/pluginpb"
39+
)
40+
41+
const (
42+
contextPackage = protogen.GoImportPath("context")
43+
connectPackage = protogen.GoImportPath("connectrpc.com/connect")
44+
45+
filenameSuffix = ".servicestruct.connect.go"
46+
handlerFuncSuffix = "HandlerFunc"
47+
serviceSuffix = "Struct"
48+
49+
usage = "\n\nFlags:\n -h, --help\tPrint this help and exit.\n --version\tPrint the version and exit."
50+
)
51+
52+
func main() {
53+
if len(os.Args) == 2 && os.Args[1] == "--version" {
54+
if _, err := fmt.Fprintln(os.Stdout, connect.Version); err != nil {
55+
os.Exit(1)
56+
}
57+
os.Exit(0)
58+
}
59+
if len(os.Args) == 2 && (os.Args[1] == "-h" || os.Args[1] == "--help") {
60+
if _, err := fmt.Fprintln(os.Stdout, usage); err != nil {
61+
os.Exit(1)
62+
}
63+
os.Exit(0)
64+
}
65+
if len(os.Args) != 1 {
66+
if _, err := fmt.Fprintln(os.Stderr, usage); err != nil {
67+
os.Exit(1)
68+
}
69+
os.Exit(1)
70+
}
71+
protogen.Options{}.Run(
72+
func(plugin *protogen.Plugin) error {
73+
plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) | uint64(pluginpb.CodeGeneratorResponse_FEATURE_SUPPORTS_EDITIONS)
74+
plugin.SupportedEditionsMinimum = descriptorpb.Edition_EDITION_PROTO2
75+
plugin.SupportedEditionsMaximum = descriptorpb.Edition_EDITION_2023
76+
77+
for _, file := range plugin.Files {
78+
if file.Generate && len(file.Services) > 0 {
79+
generateFile(plugin, file)
80+
}
81+
}
82+
83+
return nil
84+
},
85+
)
86+
}
87+
88+
func generateFile(plugin *protogen.Plugin, file *protogen.File) {
89+
// Modify package name and paths to match connect plugin behavior
90+
connectSuffix := "connect"
91+
connectPackageName := file.GoPackageName + protogen.GoPackageName(connectSuffix)
92+
dir := path.Dir(file.GeneratedFilenamePrefix)
93+
base := path.Base(file.GeneratedFilenamePrefix)
94+
connectDir := string(file.GoPackageName) + connectSuffix
95+
filename := path.Join(dir, connectDir, base+filenameSuffix)
96+
goImportPath := protogen.GoImportPath(path.Join(
97+
string(file.GoImportPath),
98+
connectDir,
99+
))
100+
generatedFile := plugin.NewGeneratedFile(filename, goImportPath)
101+
102+
// Import the base package for message types
103+
generatedFile.Import(file.GoImportPath)
104+
105+
generatedFile.P("// Code generated by protoc-gen-connect-go-servicestruct. DO NOT EDIT.")
106+
generatedFile.P("//")
107+
generatedFile.P("// Source: ", file.Desc.Path())
108+
generatedFile.P()
109+
generatedFile.P("package ", connectPackageName)
110+
generatedFile.P()
111+
112+
generateHandlerFuncTypes(generatedFile, []*protogen.File{file})
113+
114+
for _, service := range file.Services {
115+
generateServiceStruct(generatedFile, service)
116+
}
117+
}
118+
119+
func generateHandlerFuncTypes(g *protogen.GeneratedFile, files []*protogen.File) {
120+
// Generate specific handler function types for each method
121+
for _, file := range files {
122+
for _, service := range file.Services {
123+
for _, method := range service.Methods {
124+
generateMethodHandlerType(g, service, method)
125+
}
126+
}
127+
}
128+
}
129+
130+
func generateMethodHandlerType(g *protogen.GeneratedFile, service *protogen.Service, method *protogen.Method) {
131+
typeName := service.GoName + method.GoName + handlerFuncSuffix
132+
133+
if isDeprecatedMethod(method) {
134+
g.P("//")
135+
deprecated(g)
136+
}
137+
138+
// Generate the appropriate function signature based on streaming type
139+
isStreamingClient := method.Desc.IsStreamingClient()
140+
isStreamingServer := method.Desc.IsStreamingServer()
141+
142+
switch {
143+
case isStreamingClient && isStreamingServer:
144+
// Bidirectional streaming
145+
g.P("type ", typeName, " func(", contextPackage.Ident("Context"), ", *", connectPackage.Ident("BidiStream"), "[", g.QualifiedGoIdent(method.Input.GoIdent), ", ", g.QualifiedGoIdent(method.Output.GoIdent), "]) error")
146+
case isStreamingClient && !isStreamingServer:
147+
// Client streaming
148+
g.P("type ", typeName, " func(", contextPackage.Ident("Context"), ", *", connectPackage.Ident("ClientStream"), "[", g.QualifiedGoIdent(method.Input.GoIdent), "]) (*", connectPackage.Ident("Response"), "[", g.QualifiedGoIdent(method.Output.GoIdent), "], error)")
149+
case !isStreamingClient && isStreamingServer:
150+
// Server streaming
151+
g.P("type ", typeName, " func(", contextPackage.Ident("Context"), ", *", connectPackage.Ident("Request"), "[", g.QualifiedGoIdent(method.Input.GoIdent), "], *", connectPackage.Ident("ServerStream"), "[", g.QualifiedGoIdent(method.Output.GoIdent), "]) error")
152+
default:
153+
// Unary
154+
g.P("type ", typeName, " func(", contextPackage.Ident("Context"), ", *", connectPackage.Ident("Request"), "[", g.QualifiedGoIdent(method.Input.GoIdent), "]) (*", connectPackage.Ident("Response"), "[", g.QualifiedGoIdent(method.Output.GoIdent), "], error)")
155+
}
156+
g.P()
157+
}
158+
159+
func generateServiceStruct(g *protogen.GeneratedFile, service *protogen.Service) {
160+
serviceName := fmt.Sprintf("%s%s", service.GoName, serviceSuffix)
161+
162+
if isDeprecatedService(service) {
163+
g.P("//")
164+
deprecated(g)
165+
}
166+
g.AnnotateSymbol(serviceName, protogen.Annotation{Location: service.Location})
167+
g.P("type ", serviceName, " struct {")
168+
for _, method := range service.Methods {
169+
fieldName := method.GoName + "Func"
170+
g.AnnotateSymbol(serviceName+"."+fieldName, protogen.Annotation{Location: method.Location})
171+
leadingComments(
172+
g,
173+
method.Comments.Leading,
174+
isDeprecatedMethod(method),
175+
)
176+
handlerTypeName := service.GoName + method.GoName + handlerFuncSuffix
177+
g.P(fieldName, " ", handlerTypeName)
178+
}
179+
g.P("}")
180+
g.P()
181+
182+
for _, method := range service.Methods {
183+
fieldName := method.GoName + "Func"
184+
if isDeprecatedMethod(method) {
185+
g.P("//")
186+
deprecated(g)
187+
}
188+
methodSig := method.GoName + serverSignatureParams(g, method, true)
189+
g.P("func (s *", serviceName, ") ", methodSig, " {")
190+
191+
isStreamingClient := method.Desc.IsStreamingClient()
192+
isStreamingServer := method.Desc.IsStreamingServer()
193+
194+
switch {
195+
case isStreamingClient && isStreamingServer:
196+
g.P("return s.", fieldName, "(ctx, stream)")
197+
case isStreamingClient && !isStreamingServer:
198+
g.P("return s.", fieldName, "(ctx, stream)")
199+
case !isStreamingClient && isStreamingServer:
200+
g.P("return s.", fieldName, "(ctx, req, stream)")
201+
default:
202+
g.P("return s.", fieldName, "(ctx, req)")
203+
}
204+
g.P("}")
205+
g.P()
206+
}
207+
}
208+
209+
func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, named bool) string {
210+
ctxName := "ctx "
211+
reqName := "req "
212+
streamName := "stream "
213+
if !named {
214+
ctxName, reqName, streamName = "", "", ""
215+
}
216+
if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() {
217+
return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " +
218+
streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStream")) +
219+
"[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" +
220+
") error"
221+
}
222+
if method.Desc.IsStreamingClient() {
223+
return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " +
224+
streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) +
225+
"[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" +
226+
") (*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "], error)"
227+
}
228+
if method.Desc.IsStreamingServer() {
229+
return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) +
230+
", " + reqName + "*" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" +
231+
g.QualifiedGoIdent(method.Input.GoIdent) + "], " +
232+
streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ServerStream")) +
233+
"[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" +
234+
") error"
235+
}
236+
return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) +
237+
", " + reqName + "*" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" +
238+
g.QualifiedGoIdent(method.Input.GoIdent) + "]) " +
239+
"(*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" +
240+
g.QualifiedGoIdent(method.Output.GoIdent) + "], error)"
241+
}
242+
243+
func isDeprecatedService(service *protogen.Service) bool {
244+
serviceOptions, ok := service.Desc.Options().(*descriptorpb.ServiceOptions)
245+
return ok && serviceOptions.GetDeprecated()
246+
}
247+
248+
func isDeprecatedMethod(method *protogen.Method) bool {
249+
methodOptions, ok := method.Desc.Options().(*descriptorpb.MethodOptions)
250+
return ok && methodOptions.GetDeprecated()
251+
}
252+
253+
func leadingComments(g *protogen.GeneratedFile, comments protogen.Comments, isDeprecated bool) {
254+
if comments.String() != "" {
255+
g.P(strings.TrimSpace(comments.String()))
256+
}
257+
if isDeprecated {
258+
if comments.String() != "" {
259+
g.P("//")
260+
}
261+
deprecated(g)
262+
}
263+
}
264+
265+
func deprecated(g *protogen.GeneratedFile) {
266+
g.P("// Deprecated: do not use.")
267+
}

0 commit comments

Comments
 (0)