diff --git a/sdks/go/pkg/beam/io/avroio/avroio.go b/sdks/go/pkg/beam/io/avroio/avroio.go index 809c9479f7a4..3a116a74f557 100644 --- a/sdks/go/pkg/beam/io/avroio/avroio.go +++ b/sdks/go/pkg/beam/io/avroio/avroio.go @@ -19,6 +19,8 @@ package avroio import ( "context" "encoding/json" + "fmt" + "math/rand" "reflect" "github.com/apache/beam/sdks/v2/go/pkg/beam" @@ -32,7 +34,10 @@ import ( func init() { register.DoFn3x1[context.Context, fileio.ReadableFile, func(beam.X), error]((*avroReadFn)(nil)) register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil)) + register.DoFn2x0[string, func(int, string)]((*roundRobinKeyFn)(nil)) register.Emitter1[beam.X]() + register.Emitter1[string]() + register.Emitter2[int, string]() register.Iter1[string]() } @@ -109,32 +114,121 @@ func (f *avroReadFn) ProcessElement(ctx context.Context, file fileio.ReadableFil return ar.Err() } +type WriteOption func(*writeConfig) + +type writeConfig struct { + suffix string + numShards int +} + +// WithSuffix sets the file suffix (default: ".avro") +func WithSuffix(suffix string) WriteOption { + return func(c *writeConfig) { + c.suffix = suffix + } +} + +// WithNumShards sets the number of output shards (default: 1) +func WithNumShards(numShards int) WriteOption { + return func(c *writeConfig) { + c.numShards = numShards + } +} + // Write writes a PCollection to an AVRO file. // Write expects a JSON string with a matching AVRO schema. // the process will fail if the schema does not match the JSON // provided -func Write(s beam.Scope, filename, schema string, col beam.PCollection) { - s = s.Scope("avroio.Write") - filesystem.ValidateScheme(filename) - pre := beam.AddFixedKey(s, col) - post := beam.GroupByKey(s, pre) - beam.ParDo0(s, &writeAvroFn{Schema: schema, Filename: filename}, post) +// +// Parameters: +// +// prefix: File path prefix (e.g., "gs://bucket/output") +// suffix: File extension (e.g., ".avro") +// numShards: Number of output files (0 or 1 for single file) +// schema: AVRO schema as JSON string +// +// Files are named as: --of- +// Example: output-00000-of-00010.avro +// +// Examples: +// +// Write(s, "gs://bucket/output", schema, col) // output-00000-of-00001.avro (defaults) +// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro")) // output-00000-of-00001.avro (explicit) +// Write(s, "gs://bucket/output", schema, col, WithNumShards(10)) // output-00000-of-00010.avro (10 shards) +// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro"), WithNumShards(10)) // full control +func Write(s beam.Scope, prefix, schema string, col beam.PCollection, opts ...WriteOption) { + s = s.Scope("avroio.WriteSharded") + filesystem.ValidateScheme(prefix) + + config := &writeConfig{ + suffix: ".avro", + numShards: 1, + } + + for _, opt := range opts { + opt(config) + } + + // Default to single shard if not specified or 0 + if config.numShards <= 0 { + config.numShards = 1 + } + + keyed := beam.ParDo(s, &roundRobinKeyFn{NumShards: config.numShards}, col) + + grouped := beam.GroupByKey(s, keyed) + + beam.ParDo0(s, &writeAvroFn{ + Prefix: prefix, + NumShards: config.numShards, + Suffix: config.suffix, + Schema: schema, + }, grouped) +} + +type roundRobinKeyFn struct { + NumShards int `json:"num_shards"` + counter int + initialized bool +} + +func (f *roundRobinKeyFn) StartBundle(emit func(int, string)) { + f.initialized = false +} + +func (f *roundRobinKeyFn) ProcessElement(element string, emit func(int, string)) { + if !f.initialized { + f.counter = rand.Intn(f.NumShards) + f.initialized = true + } + emit(f.counter, element) + f.counter = (f.counter + 1) % f.NumShards +} + +// formatShardName creates filename: prefix-SSSSS-of-NNNNN.suffix +func formatShardName(prefix, suffix string, shardNum, numShards int) string { + width := max(len(fmt.Sprintf("%d", numShards-1)), 5) + return fmt.Sprintf("%s-%0*d-of-%0*d%s", prefix, width, shardNum, width, numShards, suffix) } type writeAvroFn struct { - Schema string `json:"schema"` - Filename string `json:"filename"` + Prefix string `json:"prefix"` + Suffix string `json:"suffix"` + NumShards int `json:"num_shards"` + Schema string `json:"schema"` } -func (w *writeAvroFn) ProcessElement(ctx context.Context, _ int, lines func(*string) bool) (err error) { - log.Infof(ctx, "writing AVRO to %s", w.Filename) - fs, err := filesystem.New(ctx, w.Filename) +func (w *writeAvroFn) ProcessElement(ctx context.Context, shardNum int, lines func(*string) bool) (err error) { + filename := formatShardName(w.Prefix, w.Suffix, shardNum, w.NumShards) + log.Infof(ctx, "Writing AVRO shard %d/%d to %s", shardNum+1, w.NumShards, filename) + + fs, err := filesystem.New(ctx, filename) if err != nil { return } defer fs.Close() - fd, err := fs.OpenWrite(ctx, w.Filename) + fd, err := fs.OpenWrite(ctx, filename) if err != nil { return } diff --git a/sdks/go/pkg/beam/io/avroio/avroio_test.go b/sdks/go/pkg/beam/io/avroio/avroio_test.go index 403a81875557..2e888b0e040c 100644 --- a/sdks/go/pkg/beam/io/avroio/avroio_test.go +++ b/sdks/go/pkg/beam/io/avroio/avroio_test.go @@ -19,7 +19,9 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "os" + "path/filepath" "reflect" "testing" @@ -141,15 +143,29 @@ const userSchema = `{ }` func TestWrite(t *testing.T) { - avroFile := "./user.avro" + testWriteDefaults(t) +} + +func TestWriteWithOptions(t *testing.T) { + testWriteWithOptions(t, 3) +} + +func testWriteDefaults(t *testing.T) { + avroPrefix := "./user" + numShards := 1 + avroSuffix := ".avro" testUsername := "user1" testInfo := "userInfo" + p, s, sequence := ptest.CreateList([]TwitterUser{{ User: testUsername, Info: testInfo, }}) format := beam.ParDo(s, toJSONString, sequence) - Write(s, avroFile, userSchema, format) + + Write(s, avroPrefix, userSchema, format) + + avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, 0, numShards, avroSuffix) t.Cleanup(func() { os.Remove(avroFile) }) @@ -189,3 +205,91 @@ func TestWrite(t *testing.T) { t.Fatalf("User.User=%v, want %v", got, want) } } + +func testWriteWithOptions(t *testing.T, numShards int) { + avroPrefix := "./users" + avroSuffix := ".avro" + users := []TwitterUser{ + {User: "user1", Info: "info1"}, + {User: "user2", Info: "info2"}, + {User: "user3", Info: "info3"}, + {User: "user4", Info: "info4"}, + {User: "user5", Info: "info5"}, + } + + p, s, sequence := ptest.CreateList(users) + format := beam.ParDo(s, toJSONString, sequence) + + Write(s, avroPrefix, userSchema, format, WithNumShards(numShards)) + + t.Cleanup(func() { + pattern := fmt.Sprintf("%s-*-of-%s%s", avroPrefix, fmt.Sprintf("%05d", numShards), avroSuffix) + files, err := filepath.Glob(pattern) + if err == nil { + for _, f := range files { + os.Remove(f) + } + } + }) + + ptest.RunAndValidate(t, p) + + var allRecords []map[string]any + recordCounts := make(map[int]int) + + for shardNum := 0; shardNum < numShards; shardNum++ { + avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, shardNum, numShards, avroSuffix) + + if _, err := os.Stat(avroFile); errors.Is(err, os.ErrNotExist) { + continue + } + + avroBytes, err := os.ReadFile(avroFile) + if err != nil { + t.Fatalf("Failed to read avro file %v: %v", avroFile, err) + } + ocf, err := goavro.NewOCFReader(bytes.NewReader(avroBytes)) + if err != nil { + t.Fatalf("Failed to make OCF Reader for %v: %v", avroFile, err) + } + shardRecordCount := 0 + for ocf.Scan() { + datum, err := ocf.Read() + if err != nil { + break + } + allRecords = append(allRecords, datum.(map[string]any)) + shardRecordCount++ + } + + recordCounts[shardNum] = shardRecordCount + + if err := ocf.Err(); err != nil { + t.Fatalf("Error decoding avro data from %v: %v", avroFile, err) + } + } + + if got, want := len(allRecords), len(users); got != want { + t.Fatalf("Total records across all shards, got %v, want %v", got, want) + } + + hasRecords := false + for _, count := range recordCounts { + if count > 0 { + hasRecords = true + } + } + if !hasRecords { + t.Fatal("No records found in any shard") + } + foundUsers := make(map[string]bool) + for _, record := range allRecords { + username := record["username"].(string) + foundUsers[username] = true + } + for _, user := range users { + if !foundUsers[user.User] { + t.Fatalf("Expected user %v not found in any shard", user.User) + } + } +}