Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions sdks/go/pkg/beam/io/bigqueryio/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"cloud.google.com/go/bigquery"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/util/structx"
Expand Down Expand Up @@ -190,13 +191,29 @@ func mustInferSchema(t reflect.Type) bigquery.Schema {
if t.Kind() != reflect.Struct {
panic(fmt.Sprintf("schema type must be struct: %v", t))
}

checkTypeRegistered(t)

schema, err := bigquery.InferSchema(reflect.Zero(t).Interface())
if err != nil {
panic(errors.Wrapf(err, "invalid schema type: %v", t))
}
return schema
}

func checkTypeRegistered(t reflect.Type) {
t = reflectx.SkipPtr(t)
key, ok := runtime.TypeKey(t)
if !ok {
panic(fmt.Sprintf("type %v must be a named type (not anonymous) for registration", t))
}

if _, registered := runtime.LookupType(key); !registered {
panic(fmt.Sprintf("type %v is not registered. Ensure that beam.RegisterType(%v) "+
"is called before beam.Init().", t, t))
}
}

func mustParseTable(table string) QualifiedTableName {
qn, err := NewQualifiedTableName(table)
if err != nil {
Expand Down
72 changes: 72 additions & 0 deletions sdks/go/pkg/beam/io/bigqueryio/bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package bigqueryio
import (
"reflect"
"testing"

"cloud.google.com/go/bigquery"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
)

func TestNewQualifiedTableName(t *testing.T) {
Expand Down Expand Up @@ -76,3 +79,72 @@ func Test_constructSelectStatementPanic(t *testing.T) {
constructSelectStatement(typ, tagKey, table)
})
}

func Test_mustInferSchema(t *testing.T) {
type TestSchema struct {
Name bigquery.NullString `bigquery:"name"`
Active bigquery.NullBool `bigquery:"active"`
Score bigquery.NullFloat64 `bigquery:"score"`
Time bigquery.NullDateTime `bigquery:"time"`
}

tests := []struct {
name string
input interface{}
wantErr bool
prep func(reflect.Type) error
verify func(reflect.Type) error
}{
{
name: "NotRegisteredType_ShouldPanic",
input: TestSchema{},
wantErr: true,
prep: func(t reflect.Type) error { return nil },
verify: func(t reflect.Type) error { return nil },
},
{
name: "AlreadyRegisteredType_ShouldNotPanic",
input: TestSchema{},
wantErr: false,
prep: func(t reflect.Type) error {
beam.RegisterType(t)
return nil
},
verify: func(t reflect.Type) error {
mustInferSchema(t)
return nil
},
},
{
name: "AnonymousStruct_ShouldPanic",
input: struct{}{},
wantErr: true,
prep: func(t reflect.Type) error { return nil },
verify: func(t reflect.Type) error { return nil },
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
r := recover()
if (r != nil) != tt.wantErr {
t.Errorf("mustInferSchema() panic = %v, wantErr %v", r, tt.wantErr)
}
}()

typ := reflect.TypeOf(tt.input)
if err := tt.prep(typ); err != nil {
t.Fatalf("failed to prep test environment, got err: %v", err)
}

mustInferSchema(typ)
if tt.wantErr {
t.Fatal("Expected panic did not occur")
}
if err := tt.verify(typ); err != nil {
t.Fatal(err)
}
})
}
}
Loading