Skip to content

Commit 79ea2e8

Browse files
Add support for sharding while avro write (#36933)
1 parent 81bb506 commit 79ea2e8

File tree

2 files changed

+212
-14
lines changed

2 files changed

+212
-14
lines changed

sdks/go/pkg/beam/io/avroio/avroio.go

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package avroio
1919
import (
2020
"context"
2121
"encoding/json"
22+
"fmt"
23+
"math/rand"
2224
"reflect"
2325

2426
"github.com/apache/beam/sdks/v2/go/pkg/beam"
@@ -32,7 +34,10 @@ import (
3234
func init() {
3335
register.DoFn3x1[context.Context, fileio.ReadableFile, func(beam.X), error]((*avroReadFn)(nil))
3436
register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil))
37+
register.DoFn2x0[string, func(int, string)]((*roundRobinKeyFn)(nil))
3538
register.Emitter1[beam.X]()
39+
register.Emitter1[string]()
40+
register.Emitter2[int, string]()
3641
register.Iter1[string]()
3742
}
3843

@@ -109,32 +114,121 @@ func (f *avroReadFn) ProcessElement(ctx context.Context, file fileio.ReadableFil
109114
return ar.Err()
110115
}
111116

117+
type WriteOption func(*writeConfig)
118+
119+
type writeConfig struct {
120+
suffix string
121+
numShards int
122+
}
123+
124+
// WithSuffix sets the file suffix (default: ".avro")
125+
func WithSuffix(suffix string) WriteOption {
126+
return func(c *writeConfig) {
127+
c.suffix = suffix
128+
}
129+
}
130+
131+
// WithNumShards sets the number of output shards (default: 1)
132+
func WithNumShards(numShards int) WriteOption {
133+
return func(c *writeConfig) {
134+
c.numShards = numShards
135+
}
136+
}
137+
112138
// Write writes a PCollection<string> to an AVRO file.
113139
// Write expects a JSON string with a matching AVRO schema.
114140
// the process will fail if the schema does not match the JSON
115141
// provided
116-
func Write(s beam.Scope, filename, schema string, col beam.PCollection) {
117-
s = s.Scope("avroio.Write")
118-
filesystem.ValidateScheme(filename)
119-
pre := beam.AddFixedKey(s, col)
120-
post := beam.GroupByKey(s, pre)
121-
beam.ParDo0(s, &writeAvroFn{Schema: schema, Filename: filename}, post)
142+
//
143+
// Parameters:
144+
//
145+
// prefix: File path prefix (e.g., "gs://bucket/output")
146+
// suffix: File extension (e.g., ".avro")
147+
// numShards: Number of output files (0 or 1 for single file)
148+
// schema: AVRO schema as JSON string
149+
//
150+
// Files are named as: <prefix>-<shard>-of-<numShards><suffix>
151+
// Example: output-00000-of-00010.avro
152+
//
153+
// Examples:
154+
//
155+
// Write(s, "gs://bucket/output", schema, col) // output-00000-of-00001.avro (defaults)
156+
// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro")) // output-00000-of-00001.avro (explicit)
157+
// Write(s, "gs://bucket/output", schema, col, WithNumShards(10)) // output-00000-of-00010.avro (10 shards)
158+
// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro"), WithNumShards(10)) // full control
159+
func Write(s beam.Scope, prefix, schema string, col beam.PCollection, opts ...WriteOption) {
160+
s = s.Scope("avroio.WriteSharded")
161+
filesystem.ValidateScheme(prefix)
162+
163+
config := &writeConfig{
164+
suffix: ".avro",
165+
numShards: 1,
166+
}
167+
168+
for _, opt := range opts {
169+
opt(config)
170+
}
171+
172+
// Default to single shard if not specified or 0
173+
if config.numShards <= 0 {
174+
config.numShards = 1
175+
}
176+
177+
keyed := beam.ParDo(s, &roundRobinKeyFn{NumShards: config.numShards}, col)
178+
179+
grouped := beam.GroupByKey(s, keyed)
180+
181+
beam.ParDo0(s, &writeAvroFn{
182+
Prefix: prefix,
183+
NumShards: config.numShards,
184+
Suffix: config.suffix,
185+
Schema: schema,
186+
}, grouped)
187+
}
188+
189+
type roundRobinKeyFn struct {
190+
NumShards int `json:"num_shards"`
191+
counter int
192+
initialized bool
193+
}
194+
195+
func (f *roundRobinKeyFn) StartBundle(emit func(int, string)) {
196+
f.initialized = false
197+
}
198+
199+
func (f *roundRobinKeyFn) ProcessElement(element string, emit func(int, string)) {
200+
if !f.initialized {
201+
f.counter = rand.Intn(f.NumShards)
202+
f.initialized = true
203+
}
204+
emit(f.counter, element)
205+
f.counter = (f.counter + 1) % f.NumShards
206+
}
207+
208+
// formatShardName creates filename: prefix-SSSSS-of-NNNNN.suffix
209+
func formatShardName(prefix, suffix string, shardNum, numShards int) string {
210+
width := max(len(fmt.Sprintf("%d", numShards-1)), 5)
211+
return fmt.Sprintf("%s-%0*d-of-%0*d%s", prefix, width, shardNum, width, numShards, suffix)
122212
}
123213

124214
type writeAvroFn struct {
125-
Schema string `json:"schema"`
126-
Filename string `json:"filename"`
215+
Prefix string `json:"prefix"`
216+
Suffix string `json:"suffix"`
217+
NumShards int `json:"num_shards"`
218+
Schema string `json:"schema"`
127219
}
128220

129-
func (w *writeAvroFn) ProcessElement(ctx context.Context, _ int, lines func(*string) bool) (err error) {
130-
log.Infof(ctx, "writing AVRO to %s", w.Filename)
131-
fs, err := filesystem.New(ctx, w.Filename)
221+
func (w *writeAvroFn) ProcessElement(ctx context.Context, shardNum int, lines func(*string) bool) (err error) {
222+
filename := formatShardName(w.Prefix, w.Suffix, shardNum, w.NumShards)
223+
log.Infof(ctx, "Writing AVRO shard %d/%d to %s", shardNum+1, w.NumShards, filename)
224+
225+
fs, err := filesystem.New(ctx, filename)
132226
if err != nil {
133227
return
134228
}
135229
defer fs.Close()
136230

137-
fd, err := fs.OpenWrite(ctx, w.Filename)
231+
fd, err := fs.OpenWrite(ctx, filename)
138232
if err != nil {
139233
return
140234
}

sdks/go/pkg/beam/io/avroio/avroio_test.go

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ import (
1919
"bytes"
2020
"encoding/json"
2121
"errors"
22+
"fmt"
2223
"os"
24+
"path/filepath"
2325
"reflect"
2426
"testing"
2527

@@ -141,15 +143,29 @@ const userSchema = `{
141143
}`
142144

143145
func TestWrite(t *testing.T) {
144-
avroFile := "./user.avro"
146+
testWriteDefaults(t)
147+
}
148+
149+
func TestWriteWithOptions(t *testing.T) {
150+
testWriteWithOptions(t, 3)
151+
}
152+
153+
func testWriteDefaults(t *testing.T) {
154+
avroPrefix := "./user"
155+
numShards := 1
156+
avroSuffix := ".avro"
145157
testUsername := "user1"
146158
testInfo := "userInfo"
159+
147160
p, s, sequence := ptest.CreateList([]TwitterUser{{
148161
User: testUsername,
149162
Info: testInfo,
150163
}})
151164
format := beam.ParDo(s, toJSONString, sequence)
152-
Write(s, avroFile, userSchema, format)
165+
166+
Write(s, avroPrefix, userSchema, format)
167+
168+
avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, 0, numShards, avroSuffix)
153169
t.Cleanup(func() {
154170
os.Remove(avroFile)
155171
})
@@ -189,3 +205,91 @@ func TestWrite(t *testing.T) {
189205
t.Fatalf("User.User=%v, want %v", got, want)
190206
}
191207
}
208+
209+
func testWriteWithOptions(t *testing.T, numShards int) {
210+
avroPrefix := "./users"
211+
avroSuffix := ".avro"
212+
users := []TwitterUser{
213+
{User: "user1", Info: "info1"},
214+
{User: "user2", Info: "info2"},
215+
{User: "user3", Info: "info3"},
216+
{User: "user4", Info: "info4"},
217+
{User: "user5", Info: "info5"},
218+
}
219+
220+
p, s, sequence := ptest.CreateList(users)
221+
format := beam.ParDo(s, toJSONString, sequence)
222+
223+
Write(s, avroPrefix, userSchema, format, WithNumShards(numShards))
224+
225+
t.Cleanup(func() {
226+
pattern := fmt.Sprintf("%s-*-of-%s%s", avroPrefix, fmt.Sprintf("%05d", numShards), avroSuffix)
227+
files, err := filepath.Glob(pattern)
228+
if err == nil {
229+
for _, f := range files {
230+
os.Remove(f)
231+
}
232+
}
233+
})
234+
235+
ptest.RunAndValidate(t, p)
236+
237+
var allRecords []map[string]any
238+
recordCounts := make(map[int]int)
239+
240+
for shardNum := 0; shardNum < numShards; shardNum++ {
241+
avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, shardNum, numShards, avroSuffix)
242+
243+
if _, err := os.Stat(avroFile); errors.Is(err, os.ErrNotExist) {
244+
continue
245+
}
246+
247+
avroBytes, err := os.ReadFile(avroFile)
248+
if err != nil {
249+
t.Fatalf("Failed to read avro file %v: %v", avroFile, err)
250+
}
251+
ocf, err := goavro.NewOCFReader(bytes.NewReader(avroBytes))
252+
if err != nil {
253+
t.Fatalf("Failed to make OCF Reader for %v: %v", avroFile, err)
254+
}
255+
shardRecordCount := 0
256+
for ocf.Scan() {
257+
datum, err := ocf.Read()
258+
if err != nil {
259+
break
260+
}
261+
allRecords = append(allRecords, datum.(map[string]any))
262+
shardRecordCount++
263+
}
264+
265+
recordCounts[shardNum] = shardRecordCount
266+
267+
if err := ocf.Err(); err != nil {
268+
t.Fatalf("Error decoding avro data from %v: %v", avroFile, err)
269+
}
270+
}
271+
272+
if got, want := len(allRecords), len(users); got != want {
273+
t.Fatalf("Total records across all shards, got %v, want %v", got, want)
274+
}
275+
276+
hasRecords := false
277+
for _, count := range recordCounts {
278+
if count > 0 {
279+
hasRecords = true
280+
}
281+
}
282+
if !hasRecords {
283+
t.Fatal("No records found in any shard")
284+
}
285+
foundUsers := make(map[string]bool)
286+
for _, record := range allRecords {
287+
username := record["username"].(string)
288+
foundUsers[username] = true
289+
}
290+
for _, user := range users {
291+
if !foundUsers[user.User] {
292+
t.Fatalf("Expected user %v not found in any shard", user.User)
293+
}
294+
}
295+
}

0 commit comments

Comments
 (0)