Skip to content

Commit 9675e5b

Browse files
sdks/go: add string utf-8 check to vet runner for serialization (#33949)
* sdks/go: utf-8 check on exported fields vet runner * sdks/go: test utf-8 check on export fields in vet * sdks/go: go doc utf-8 check comment on encode DoFn
1 parent 597c785 commit 9675e5b

File tree

3 files changed

+150
-10
lines changed

3 files changed

+150
-10
lines changed

sdks/go/pkg/beam/core/runtime/graphx/serialize.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ func ms2duration(d int64) time.Duration {
208208
return time.Duration(d) * time.Millisecond
209209
}
210210

211+
// encodeFn encodes a graph.Fn into a v1pb.Fn proto message.
212+
// All string fields in the DoFn struct must be UTF-8 compliant. The vet runner
213+
// (--beam_strict) will detect any non-UTF8 strings that would fail during JSON serialization.
214+
// The check will be skipped for subtypes that implement the MarshalJSON and
215+
// UnmarshalJSON interface methods.
211216
func encodeFn(u *graph.Fn) (*v1pb.Fn, error) {
212217
switch {
213218
case u.DynFn != nil:

sdks/go/pkg/beam/runners/vet/vet.go

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ package vet
2727
import (
2828
"bytes"
2929
"context"
30+
"encoding/json"
3031
"fmt"
3132
"reflect"
3233
"strings"
@@ -91,8 +92,8 @@ func Evaluate(_ context.Context, p *beam.Pipeline) (*Eval, error) {
9192
e := newEval()
9293

9394
e.diag("/**\n")
94-
e.extractFromMultiEdges(edges)
95-
return e, nil
95+
err = e.extractFromMultiEdges(edges)
96+
return e, err
9697
}
9798

9899
func newEval() *Eval {
@@ -133,22 +134,27 @@ type Eval struct {
133134

134135
// extractFromMultiEdges audits the given pipeline edges so we can determine if
135136
// this pipeline will run without reflection.
136-
func (e *Eval) extractFromMultiEdges(edges []*graph.MultiEdge) {
137+
func (e *Eval) extractFromMultiEdges(edges []*graph.MultiEdge) error {
137138
e.diag("PTransform Audit:\n")
138139
for _, edge := range edges {
139140
switch edge.Op {
140141
case graph.ParDo:
141142
// Gets the ParDo's identifier
142143
e.diagf("pardo %s", edge.Name())
143-
e.extractGraphFn((*graph.Fn)(edge.DoFn))
144+
if err := e.extractGraphFn((*graph.Fn)(edge.DoFn)); err != nil {
145+
return err
146+
}
144147
case graph.Combine:
145148
e.diagf("combine %s", edge.Name())
146-
e.extractGraphFn((*graph.Fn)(edge.CombineFn))
149+
if err := e.extractGraphFn((*graph.Fn)(edge.CombineFn)); err != nil {
150+
return err
151+
}
147152
default:
148153
continue
149154
}
150155
e.diag("\n")
151156
}
157+
return nil
152158
}
153159

154160
// Performant returns whether this pipeline needs additional registrations
@@ -485,6 +491,73 @@ func (e *Eval) Bytes() []byte {
485491
return e.w.Bytes()
486492
}
487493

494+
// checkStructFieldsUTF8 recursively validates that all string fields in the
495+
// given value are UTF-8 compliant.
496+
// It handles structs, slices, arrays, maps, and individual strings while
497+
// avoiding infinite recursion on circular references.
498+
// The function skips validation for types that implement both json.Marshaler
499+
// and json.Unmarshaler interfaces.
500+
//
501+
// Parameters:
502+
// - v: reflect.Value to check
503+
// - seen: map tracking visited values to prevent infinite recursion
504+
//
505+
// Returns:
506+
// - error if any string field contains invalid UTF-8 encoding, nil otherwise
507+
func (e *Eval) checkStructFieldsUTF8(v reflect.Value, seen map[reflect.Value]bool) error {
508+
if !v.IsValid() || seen[v] {
509+
return nil
510+
}
511+
512+
// Track visited values to prevent infinite recursion on circular references.
513+
seen[v] = true
514+
515+
t := v.Type()
516+
517+
// Skip if type implements JSON marshaling.
518+
_, hasMarshaler := reflect.New(t).Interface().(json.Marshaler)
519+
_, hasUnmarshaler := reflect.New(t).Interface().(json.Unmarshaler)
520+
if hasMarshaler && hasUnmarshaler {
521+
return nil
522+
}
523+
524+
switch t.Kind() {
525+
case reflect.Struct:
526+
for i := 0; i < v.NumField(); i++ {
527+
field := v.Field(i)
528+
if !field.CanInterface() {
529+
// Skip unexported fields.
530+
continue
531+
}
532+
if err := e.checkStructFieldsUTF8(field, seen); err != nil {
533+
return err
534+
}
535+
}
536+
case reflect.Slice, reflect.Array:
537+
for i := 0; i < v.Len(); i++ {
538+
if err := e.checkStructFieldsUTF8(v.Index(i), seen); err != nil {
539+
return err
540+
}
541+
}
542+
case reflect.Map:
543+
iter := v.MapRange()
544+
for iter.Next() {
545+
if err := e.checkStructFieldsUTF8(iter.Key(), seen); err != nil {
546+
return err
547+
}
548+
if err := e.checkStructFieldsUTF8(iter.Value(), seen); err != nil {
549+
return err
550+
}
551+
}
552+
case reflect.String:
553+
str := v.String()
554+
if !utf8.ValidString(str) {
555+
return fmt.Errorf("non-UTF8 compliant string found: %q", str)
556+
}
557+
}
558+
return nil
559+
}
560+
488561
// We need to take graph.Fns (which can be created from any from graph.NewFn)
489562
// and convert them to all needed function caller signatures,
490563
// and emitters.
@@ -500,17 +573,29 @@ func (e *Eval) Bytes() []byte {
500573

501574
// extractGraphFn does the analysis of the function and determines what things need generating.
502575
// A single line is used, unless it's a struct, at which point one line per implemented method
503-
// is used.
504-
func (e *Eval) extractGraphFn(fn *graph.Fn) {
576+
// is used. For structs, it also validates UTF-8 compliance of all exported string fields.
577+
func (e *Eval) extractGraphFn(fn *graph.Fn) error {
505578
if fn.DynFn != nil {
506579
// TODO(https://github.com/apache/beam/issues/19401) handle dynamics if necessary (probably not since it's got general function handling)
507580
e.diag(" dynamic function")
508-
return
581+
return nil
509582
}
510583
if fn.Recv != nil {
511584
e.diagf(" struct[[%T]]", fn.Recv)
512585

513-
rt := reflectx.SkipPtr(reflect.TypeOf(fn.Recv)) // We need the value not the pointer that's used.
586+
// We need the value not the pointer that's used.
587+
rt := reflectx.SkipPtr(reflect.TypeOf(fn.Recv))
588+
rv := reflect.ValueOf(fn.Recv)
589+
if rv.Kind() == reflect.Ptr {
590+
rv = rv.Elem()
591+
}
592+
593+
// Add UTF-8 compliance check for struct fields.
594+
seen := make(map[reflect.Value]bool)
595+
if err := e.checkStructFieldsUTF8(rv, seen); err != nil {
596+
return err
597+
}
598+
514599
if tk, ok := runtime.TypeKey(rt); ok {
515600
if t, found := runtime.LookupType(tk); !found {
516601
e.needType(tk, rt)
@@ -532,6 +617,8 @@ func (e *Eval) extractGraphFn(fn *graph.Fn) {
532617
}
533618
e.extractFuncxFn(fn.Fn)
534619
}
620+
621+
return nil
535622
}
536623

537624
type mthd struct {

sdks/go/pkg/beam/runners/vet/vet_test.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,77 @@ package vet
1717

1818
import (
1919
"context"
20-
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/vet/testpipeline"
20+
"strings"
2121
"testing"
2222

23+
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/vet/testpipeline"
24+
2325
"github.com/apache/beam/sdks/v2/go/pkg/beam"
2426
)
2527

28+
type stringContentDoFn struct {
29+
Name string
30+
}
31+
32+
func (fn *stringContentDoFn) ProcessElement(ctx context.Context, _ []byte) error {
33+
return nil
34+
}
35+
36+
type errorType int
37+
38+
const (
39+
noError errorType = iota
40+
utf8Error
41+
)
42+
2643
func TestEvaluate(t *testing.T) {
2744
tests := []struct {
2845
name string
2946
c func(beam.Scope)
3047
perf, exp, ref, reg bool
48+
errType errorType
49+
errMsg string
3150
}{
3251
{name: "Performant", c: testpipeline.Performant, perf: true},
3352
{name: "FunctionReg", c: testpipeline.FunctionReg, exp: true, ref: true, reg: true},
3453
{name: "ShimNeeded", c: testpipeline.ShimNeeded, ref: true},
3554
{name: "TypeReg", c: testpipeline.TypeReg, ref: true, reg: true},
55+
{
56+
name: "NonUTF8DoFn",
57+
c: func(s beam.Scope) {
58+
fn := &stringContentDoFn{Name: "hello\xFFworld"}
59+
beam.ParDo0(s, fn, beam.Impulse(s))
60+
},
61+
errType: utf8Error,
62+
errMsg: "non-UTF8 compliant string found",
63+
},
64+
{
65+
name: "ValidUTF8DoFn",
66+
c: func(s beam.Scope) {
67+
fn := &stringContentDoFn{Name: "helloworld"}
68+
beam.ParDo0(s, fn, beam.Impulse(s))
69+
},
70+
errType: noError,
71+
ref: true,
72+
reg: true,
73+
},
3674
}
75+
3776
for _, test := range tests {
3877
test := test
3978
t.Run(test.name, func(t *testing.T) {
4079
p, s := beam.NewPipelineWithRoot()
4180
test.c(s)
4281
e, err := Evaluate(context.Background(), p)
82+
if test.errType != noError {
83+
if err == nil {
84+
t.Fatal("expected error, got nil")
85+
}
86+
if !strings.Contains(err.Error(), test.errMsg) {
87+
t.Fatalf("error %q doesn't contain %q", err.Error(), test.errMsg)
88+
}
89+
return
90+
}
4391
if err != nil {
4492
t.Fatalf("failed to evaluate testpipeline.Pipeline: %v", err)
4593
}

0 commit comments

Comments
 (0)