sdks/go/pkg/beam/io/avroio/avroio.go (106 additions, 12 deletions)
@@ -19,6 +19,8 @@ package avroio
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"reflect"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
@@ -32,7 +34,10 @@ import (
func init() {
register.DoFn3x1[context.Context, fileio.ReadableFile, func(beam.X), error]((*avroReadFn)(nil))
register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil))
register.DoFn2x0[string, func(int, string)]((*roundRobinKeyFn)(nil))
register.Emitter1[beam.X]()
register.Emitter1[string]()
register.Emitter2[int, string]()
register.Iter1[string]()
}

@@ -109,32 +114,121 @@ func (f *avroReadFn) ProcessElement(ctx context.Context, file fileio.ReadableFil
return ar.Err()
}

type WriteOption func(*writeConfig)

type writeConfig struct {
suffix string
numShards int
}

// WithSuffix sets the file suffix (default: ".avro")
func WithSuffix(suffix string) WriteOption {
return func(c *writeConfig) {
c.suffix = suffix
}
}

// WithNumShards sets the number of output shards (default: 1)
func WithNumShards(numShards int) WriteOption {
return func(c *writeConfig) {
c.numShards = numShards
}
}

// Write writes a PCollection<string> to one or more AVRO files.
// Write expects a JSON string with a matching AVRO schema; the process
// will fail if the schema does not match the JSON provided.
func Write(s beam.Scope, filename, schema string, col beam.PCollection) {
s = s.Scope("avroio.Write")
filesystem.ValidateScheme(filename)
pre := beam.AddFixedKey(s, col)
post := beam.GroupByKey(s, pre)
beam.ParDo0(s, &writeAvroFn{Schema: schema, Filename: filename}, post)
//
// Parameters:
//
// prefix: File path prefix (e.g., "gs://bucket/output")
// suffix: File extension (e.g., ".avro")
// numShards: Number of output files (0 or 1 for single file)
// schema: AVRO schema as JSON string
//
// Files are named as: <prefix>-<shard>-of-<numShards><suffix>
// Example: output-00000-of-00010.avro
//
// Examples:
//
// Write(s, "gs://bucket/output", schema, col) // output-00000-of-00001.avro (defaults)
// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro")) // output-00000-of-00001.avro (explicit)
// Write(s, "gs://bucket/output", schema, col, WithNumShards(10)) // output-00000-of-00010.avro (10 shards)
// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro"), WithNumShards(10)) // full control
func Write(s beam.Scope, prefix, schema string, col beam.PCollection, opts ...WriteOption) {
s = s.Scope("avroio.WriteSharded")
filesystem.ValidateScheme(prefix)

config := &writeConfig{
suffix: ".avro",
numShards: 1,
}

for _, opt := range opts {
opt(config)
}

// Default to a single shard when numShards is unset, zero, or negative
if config.numShards <= 0 {
config.numShards = 1
}

keyed := beam.ParDo(s, &roundRobinKeyFn{NumShards: config.numShards}, col)

grouped := beam.GroupByKey(s, keyed)

beam.ParDo0(s, &writeAvroFn{
Prefix: prefix,
NumShards: config.numShards,
Suffix: config.suffix,
Schema: schema,
}, grouped)
}

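// roundRobinKeyFn assigns each element a shard key in [0, NumShards),
// rotating round-robin within a bundle so elements spread evenly across shards.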
type roundRobinKeyFn struct {
NumShards int `json:"num_shards"`
counter int
initialized bool
}

func (f *roundRobinKeyFn) StartBundle(emit func(int, string)) {
f.initialized = false
}

func (f *roundRobinKeyFn) ProcessElement(element string, emit func(int, string)) {
if !f.initialized {
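// Start this bundle's rotation at a random shard so that concurrent
// bundles do not all begin emitting to shard 0.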
f.counter = rand.Intn(f.NumShards)
f.initialized = true
}
emit(f.counter, element)
f.counter = (f.counter + 1) % f.NumShards
}

// formatShardName builds the shard filename: <prefix>-SSSSS-of-NNNNN<suffix>
func formatShardName(prefix, suffix string, shardNum, numShards int) string {
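// Zero-pad shard numbers to at least five digits, widening if the
// largest shard index (numShards-1) needs more.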
width := max(len(fmt.Sprintf("%d", numShards-1)), 5)
return fmt.Sprintf("%s-%0*d-of-%0*d%s", prefix, width, shardNum, width, numShards, suffix)
}

type writeAvroFn struct {
Schema string `json:"schema"`
Filename string `json:"filename"`
Prefix string `json:"prefix"`
Suffix string `json:"suffix"`
NumShards int `json:"num_shards"`
Schema string `json:"schema"`
}

func (w *writeAvroFn) ProcessElement(ctx context.Context, _ int, lines func(*string) bool) (err error) {
log.Infof(ctx, "writing AVRO to %s", w.Filename)
fs, err := filesystem.New(ctx, w.Filename)
func (w *writeAvroFn) ProcessElement(ctx context.Context, shardNum int, lines func(*string) bool) (err error) {
filename := formatShardName(w.Prefix, w.Suffix, shardNum, w.NumShards)
log.Infof(ctx, "Writing AVRO shard %d/%d to %s", shardNum+1, w.NumShards, filename)

fs, err := filesystem.New(ctx, filename)
if err != nil {
return
}
defer fs.Close()

fd, err := fs.OpenWrite(ctx, w.Filename)
fd, err := fs.OpenWrite(ctx, filename)
if err != nil {
return
}
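For orientation, a minimal sketch of how a pipeline might call the sharded Write introduced above. The output path, inline schema, and sample element are placeholders, and the blank import of the local filesystem is an assumption for writing to a local path; none of this is part of the change itself.

package main

import (
	"context"
	"flag"

	"github.com/apache/beam/sdks/v2/go/pkg/beam"
	"github.com/apache/beam/sdks/v2/go/pkg/beam/io/avroio"
	_ "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/local"
	"github.com/apache/beam/sdks/v2/go/pkg/beam/x/beamx"
)

func main() {
	flag.Parse()
	beam.Init()
	p, s := beam.NewPipelineWithRoot()

	// Placeholder schema and element; a real pipeline would derive the JSON
	// strings from its own records.
	schema := `{"type": "record", "name": "user", "fields": [` +
		`{"name": "username", "type": "string"}, {"name": "info", "type": "string"}]}`
	col := beam.CreateList(s, []string{`{"username": "user1", "info": "info1"}`})

	// Request ten shards; output files are named output-SSSSS-of-00010.avro.
	avroio.Write(s, "/tmp/output", schema, col,
		avroio.WithSuffix(".avro"),
		avroio.WithNumShards(10),
	)

	if err := beamx.Run(context.Background(), p); err != nil {
		panic(err)
	}
}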
sdks/go/pkg/beam/io/avroio/avroio_test.go (106 additions, 2 deletions)
@@ -19,7 +19,9 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"reflect"
"testing"

@@ -141,15 +143,29 @@ const userSchema = `{
}`

func TestWrite(t *testing.T) {
avroFile := "./user.avro"
testWriteDefaults(t)
}

func TestWriteWithOptions(t *testing.T) {
testWriteWithOptions(t, 3)
}

func testWriteDefaults(t *testing.T) {
avroPrefix := "./user"
numShards := 1
avroSuffix := ".avro"
testUsername := "user1"
testInfo := "userInfo"

p, s, sequence := ptest.CreateList([]TwitterUser{{
User: testUsername,
Info: testInfo,
}})
format := beam.ParDo(s, toJSONString, sequence)
Write(s, avroFile, userSchema, format)

Write(s, avroPrefix, userSchema, format)

avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, 0, numShards, avroSuffix)
t.Cleanup(func() {
os.Remove(avroFile)
})
@@ -189,3 +205,91 @@
t.Fatalf("User.User=%v, want %v", got, want)
}
}

func testWriteWithOptions(t *testing.T, numShards int) {
avroPrefix := "./users"
avroSuffix := ".avro"
users := []TwitterUser{
{User: "user1", Info: "info1"},
{User: "user2", Info: "info2"},
{User: "user3", Info: "info3"},
{User: "user4", Info: "info4"},
{User: "user5", Info: "info5"},
}

p, s, sequence := ptest.CreateList(users)
format := beam.ParDo(s, toJSONString, sequence)

Write(s, avroPrefix, userSchema, format, WithNumShards(numShards))

t.Cleanup(func() {
pattern := fmt.Sprintf("%s-*-of-%s%s", avroPrefix, fmt.Sprintf("%05d", numShards), avroSuffix)
files, err := filepath.Glob(pattern)
if err == nil {
for _, f := range files {
os.Remove(f)
}
}
})

ptest.RunAndValidate(t, p)

var allRecords []map[string]any
recordCounts := make(map[int]int)

for shardNum := 0; shardNum < numShards; shardNum++ {
avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, shardNum, numShards, avroSuffix)

if _, err := os.Stat(avroFile); errors.Is(err, os.ErrNotExist) {
continue
}

avroBytes, err := os.ReadFile(avroFile)
if err != nil {
t.Fatalf("Failed to read avro file %v: %v", avroFile, err)
}
ocf, err := goavro.NewOCFReader(bytes.NewReader(avroBytes))
if err != nil {
t.Fatalf("Failed to make OCF Reader for %v: %v", avroFile, err)
}
shardRecordCount := 0
for ocf.Scan() {
datum, err := ocf.Read()
if err != nil {
break
}
allRecords = append(allRecords, datum.(map[string]any))
shardRecordCount++
}

recordCounts[shardNum] = shardRecordCount

if err := ocf.Err(); err != nil {
t.Fatalf("Error decoding avro data from %v: %v", avroFile, err)
}
}

if got, want := len(allRecords), len(users); got != want {
t.Fatalf("Total records across all shards, got %v, want %v", got, want)
}

hasRecords := false
for _, count := range recordCounts {
if count > 0 {
hasRecords = true
}
}
if !hasRecords {
t.Fatal("No records found in any shard")
}
foundUsers := make(map[string]bool)
for _, record := range allRecords {
username := record["username"].(string)
foundUsers[username] = true
}
for _, user := range users {
if !foundUsers[user.User] {
t.Fatalf("Expected user %v not found in any shard", user.User)
}
}
}
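As a side note on the filenames the 3-shard test above globs for, a small standalone sketch that mirrors formatShardName's padding; the width stays at the 5-digit minimum because numShards-1 has fewer than five digits.

package main

import "fmt"

func main() {
	// Mirrors formatShardName from avroio.go for the test's 3-shard case.
	prefix, suffix, numShards := "./users", ".avro", 3
	width := 5 // max(len("2"), 5) == 5
	for shard := 0; shard < numShards; shard++ {
		fmt.Printf("%s-%0*d-of-%0*d%s\n", prefix, width, shard, width, numShards, suffix)
	}
	// Output:
	// ./users-00000-of-00003.avro
	// ./users-00001-of-00003.avro
	// ./users-00002-of-00003.avro
}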