Skip to content

Commit 96cb2e2

Browse files
glerchundijohanbrandhorst
authored andcommitted
gengateway: allow opting out patch feature
Closes #839
1 parent a6d3ad2 commit 96cb2e2

File tree

4 files changed

+98
-8
lines changed

4 files changed

+98
-8
lines changed

protoc-gen-grpc-gateway/gengateway/generator.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ type generator struct {
3232
useRequestContext bool
3333
registerFuncSuffix string
3434
pathType pathType
35+
allowPatchFeature bool
3536
}
3637

3738
// New returns a new generator which generates grpc gateway files.
38-
func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString string) gen.Generator {
39+
func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString string, allowPatchFeature bool) gen.Generator {
3940
var imports []descriptor.GoPackage
4041
for _, pkgpath := range []string{
4142
"io",
@@ -82,6 +83,7 @@ func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, p
8283
useRequestContext: useRequestContext,
8384
registerFuncSuffix: registerFuncSuffix,
8485
pathType: pathType,
86+
allowPatchFeature: allowPatchFeature,
8587
}
8688
}
8789

@@ -142,6 +144,7 @@ func (g *generator) generate(file *descriptor.File) (string, error) {
142144
Imports: imports,
143145
UseRequestContext: g.useRequestContext,
144146
RegisterFuncSuffix: g.registerFuncSuffix,
147+
AllowPatchFeature: g.allowPatchFeature,
145148
}
146149
return applyTemplate(params, g.reg)
147150
}

protoc-gen-grpc-gateway/gengateway/template.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ type param struct {
1717
Imports []descriptor.GoPackage
1818
UseRequestContext bool
1919
RegisterFuncSuffix string
20+
AllowPatchFeature bool
2021
}
2122

2223
type binding struct {
2324
*descriptor.Binding
24-
Registry *descriptor.Registry
25+
Registry *descriptor.Registry
26+
AllowPatchFeature bool
2527
}
2628

2729
// GetBodyFieldPath returns the binding body's fieldpath.
@@ -153,7 +155,11 @@ func applyTemplate(p param, reg *descriptor.Registry) (string, error) {
153155
meth.Name = &methName
154156
for _, b := range meth.Bindings {
155157
methodWithBindingsSeen = true
156-
if err := handlerTemplate.Execute(w, binding{Binding: b, Registry: reg}); err != nil {
158+
if err := handlerTemplate.Execute(w, binding{
159+
Binding: b,
160+
Registry: reg,
161+
AllowPatchFeature: p.AllowPatchFeature,
162+
}); err != nil {
157163
return "", err
158164
}
159165
}
@@ -268,6 +274,7 @@ func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx cont
268274
`))
269275

270276
_ = template.Must(handlerTemplate.New("client-rpc-request-func").Parse(`
277+
{{$AllowPatchFeature := .AllowPatchFeature}}
271278
{{if .HasQueryParam}}
272279
var (
273280
filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}} = {{.QueryParamFilter}}
@@ -284,7 +291,7 @@ var (
284291
if err := marshaler.NewDecoder(newReader()).Decode(&{{.Body.AssignableExpr "protoReq"}}); err != nil && err != io.EOF {
285292
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
286293
}
287-
{{- if and (eq (.HTTPMethod) "PATCH") (.FieldMaskField)}}
294+
{{- if and $AllowPatchFeature (and (eq (.HTTPMethod) "PATCH") (.FieldMaskField))}}
288295
if protoReq.{{.FieldMaskField}} != nil && len(protoReq.{{.FieldMaskField}}.GetPaths()) > 0 {
289296
runtime.CamelCaseFieldMask(protoReq.{{.FieldMaskField}})
290297
} {{if not (eq "*" .GetBodyFieldPath)}} else {

protoc-gen-grpc-gateway/gengateway/template_test.go

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func TestApplyTemplateHeader(t *testing.T) {
7777
},
7878
},
7979
}
80-
got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler"}, descriptor.NewRegistry())
80+
got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
8181
if err != nil {
8282
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
8383
return
@@ -222,7 +222,7 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) {
222222
},
223223
},
224224
}
225-
got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler"}, descriptor.NewRegistry())
225+
got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
226226
if err != nil {
227227
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
228228
return
@@ -383,7 +383,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) {
383383
},
384384
},
385385
}
386-
got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler"}, descriptor.NewRegistry())
386+
got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
387387
if err != nil {
388388
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
389389
return
@@ -402,3 +402,82 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) {
402402
}
403403
}
404404
}
405+
406+
func TestAllowPatchFeature(t *testing.T) {
407+
updateMaskDesc := &protodescriptor.FieldDescriptorProto{
408+
Name: proto.String("UpdateMask"),
409+
Label: protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
410+
Type: protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
411+
TypeName: proto.String(".google.protobuf.FieldMask"),
412+
Number: proto.Int32(1),
413+
}
414+
msgdesc := &protodescriptor.DescriptorProto{
415+
Name: proto.String("ExampleMessage"),
416+
Field: []*protodescriptor.FieldDescriptorProto{updateMaskDesc},
417+
}
418+
meth := &protodescriptor.MethodDescriptorProto{
419+
Name: proto.String("Example"),
420+
InputType: proto.String("ExampleMessage"),
421+
OutputType: proto.String("ExampleMessage"),
422+
}
423+
svc := &protodescriptor.ServiceDescriptorProto{
424+
Name: proto.String("ExampleService"),
425+
Method: []*protodescriptor.MethodDescriptorProto{meth},
426+
}
427+
msg := &descriptor.Message{
428+
DescriptorProto: msgdesc,
429+
}
430+
updateMaskField := &descriptor.Field{
431+
Message: msg,
432+
FieldDescriptorProto: updateMaskDesc,
433+
}
434+
msg.Fields = append(msg.Fields, updateMaskField)
435+
file := descriptor.File{
436+
FileDescriptorProto: &protodescriptor.FileDescriptorProto{
437+
Name: proto.String("example.proto"),
438+
Package: proto.String("example"),
439+
MessageType: []*protodescriptor.DescriptorProto{msgdesc},
440+
Service: []*protodescriptor.ServiceDescriptorProto{svc},
441+
},
442+
GoPkg: descriptor.GoPackage{
443+
Path: "example.com/path/to/example/example.pb",
444+
Name: "example_pb",
445+
},
446+
Messages: []*descriptor.Message{msg},
447+
Services: []*descriptor.Service{
448+
{
449+
ServiceDescriptorProto: svc,
450+
Methods: []*descriptor.Method{
451+
{
452+
MethodDescriptorProto: meth,
453+
RequestType: msg,
454+
ResponseType: msg,
455+
Bindings: []*descriptor.Binding{
456+
{
457+
HTTPMethod: "PATCH",
458+
Body: &descriptor.Body{FieldPath: nil},
459+
},
460+
},
461+
},
462+
},
463+
},
464+
},
465+
}
466+
want := "if protoReq.UpdateMask != nil && len(protoReq.UpdateMask.GetPaths()) > 0 {\n"
467+
for _, allowPatchFeature := range []bool{true, false} {
468+
got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: allowPatchFeature}, descriptor.NewRegistry())
469+
if err != nil {
470+
t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
471+
return
472+
}
473+
if allowPatchFeature {
474+
if !strings.Contains(got, want) {
475+
t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
476+
}
477+
} else {
478+
if strings.Contains(got, want) {
479+
t.Errorf("applyTemplate(%#v) = %s; want to _not_ contain %s", file, got, want)
480+
}
481+
}
482+
}
483+
}

protoc-gen-grpc-gateway/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ var (
3232
pathType = flag.String("paths", "", "specifies how the paths of generated files are structured")
3333
allowRepeatedFieldsInBody = flag.Bool("allow_repeated_fields_in_body", false, "allows to use repeated field in `body` and `response_body` field of `google.api.http` annotation option")
3434
repeatedPathParamSeparator = flag.String("repeated_path_param_separator", "csv", "configures how repeated fields should be split. Allowed values are `csv`, `pipes`, `ssv` and `tsv`.")
35+
allowPatchFeature = flag.Bool("allow_patch_feature", true, "determines whether to use PATCH feature involving update masks (using google.protobuf.FieldMask).")
3536
versionFlag = flag.Bool("version", false, "print the current verison")
3637
)
3738

@@ -79,7 +80,7 @@ func main() {
7980
}
8081
}
8182

82-
g := gengateway.New(reg, *useRequestContext, *registerFuncSuffix, *pathType)
83+
g := gengateway.New(reg, *useRequestContext, *registerFuncSuffix, *pathType, *allowPatchFeature)
8384

8485
if *grpcAPIConfiguration != "" {
8586
if err := reg.LoadGrpcAPIServiceFromYAML(*grpcAPIConfiguration); err != nil {

0 commit comments

Comments
 (0)