Skip to content

Commit b5b9181

Browse files
authored
Move the logic to LP TestStream encoded bytes to preprocess steps. (#36465)
1 parent 243a52c commit b5b9181

File tree

3 files changed

+115
-54
lines changed

3 files changed

+115
-54
lines changed

sdks/go/pkg/beam/runners/prism/internal/coders.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,11 @@ func lpUnknownCoders(cID string, bundle, base map[string]*pipepb.Coder) (string,
199199
}
200200

201201
// forceLpCoder always adds a new LP-coder for a given coder into the "bundle" map
202-
func forceLpCoder(cID string, base map[string]*pipepb.Coder) (string, error) {
202+
func forceLpCoder(cID string, bundle, base map[string]*pipepb.Coder) (string, error) {
203203
// First check if we've already added the LP version of this coder to coders already.
204204
lpcID := cID + "_flp"
205205
// Check if we've done this one before.
206-
if _, ok := base[lpcID]; ok {
206+
if _, ok := bundle[lpcID]; ok {
207207
return lpcID, nil
208208
}
209209
// Look up the canonical location.
@@ -219,7 +219,7 @@ func forceLpCoder(cID string, base map[string]*pipepb.Coder) (string, error) {
219219
},
220220
ComponentCoderIds: []string{cID},
221221
}
222-
base[lpcID] = lpc
222+
bundle[lpcID] = lpc
223223
return lpcID, nil
224224
}
225225

sdks/go/pkg/beam/runners/prism/internal/execute.go

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
package internal
1717

1818
import (
19-
"bytes"
2019
"context"
2120
"errors"
2221
"fmt"
@@ -27,7 +26,6 @@ import (
2726
"sync/atomic"
2827
"time"
2928

30-
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
3129
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
3230
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
3331
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
@@ -270,67 +268,21 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic
270268
case urns.TransformTestStream:
271269
// Add a synthetic stage that should largely be unused.
272270
em.AddStage(stage.ID, nil, maps.Values(t.GetOutputs()), nil)
271+
273272
// Decode the test stream, and convert it to the various events for the ElementManager.
274273
var pyld pipepb.TestStreamPayload
275274
if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pyld); err != nil {
276275
return fmt.Errorf("prism error building stage %v - decoding TestStreamPayload: \n%w", stage.ID, err)
277276
}
278277

279-
// Ensure awareness of the coder used for the teststream.
280-
cID, err := lpUnknownCoders(pyld.GetCoderId(), coders, comps.GetCoders())
281-
if err != nil {
282-
panic(err)
283-
}
284-
mayLP := func(v []byte) []byte {
285-
//slog.Warn("teststream bytes", "value", string(v), "bytes", v)
286-
return v
287-
}
288-
// If the TestStream coder needs to be LP'ed or if it is a coder that has different
289-
// behaviors between nested context and outer context (in Java SDK), then we must
290-
// LP this coder and the TestStream data elements.
291-
forceLP := cID != pyld.GetCoderId() ||
292-
coders[cID].GetSpec().GetUrn() == urns.CoderStringUTF8 ||
293-
coders[cID].GetSpec().GetUrn() == urns.CoderBytes ||
294-
coders[cID].GetSpec().GetUrn() == urns.CoderKV
295-
if forceLP {
296-
// slog.Warn("recoding TestStreamValue", "cID", cID, "newUrn", coders[cID].GetSpec().GetUrn(), "payloadCoder", pyld.GetCoderId(), "oldUrn", coders[pyld.GetCoderId()].GetSpec().GetUrn())
297-
// The coder needed length prefixing. For simplicity, add a length prefix to each
298-
// encoded element, since we will be sending a length prefixed coder to consume
299-
// this anyway. This is simpler than trying to find all the re-written coders after the fact.
300-
// This also adds a LP-coder for the original coder in comps.
301-
cID, err := forceLpCoder(pyld.GetCoderId(), comps.GetCoders())
302-
if err != nil {
303-
panic(err)
304-
}
305-
slog.Debug("teststream: add coder", "coderId", cID)
306-
307-
mayLP = func(v []byte) []byte {
308-
var buf bytes.Buffer
309-
if err := coder.EncodeVarInt((int64)(len(v)), &buf); err != nil {
310-
panic(err)
311-
}
312-
if _, err := buf.Write(v); err != nil {
313-
panic(err)
314-
}
315-
//slog.Warn("teststream bytes - after LP", "value", string(v), "bytes", buf.Bytes())
316-
return buf.Bytes()
317-
}
318-
319-
// we need to change Coder and Pcollection in comps directly before they are used to build descriptors
320-
for _, col := range t.GetOutputs() {
321-
oCID := comps.Pcollections[col].CoderId
322-
comps.Pcollections[col].CoderId = cID
323-
slog.Debug("teststream: rewrite coder for output pcoll", "colId", col, "oldId", oCID, "newId", cID)
324-
}
325-
}
326-
327278
tsb := em.AddTestStream(stage.ID, t.Outputs)
328279
for _, e := range pyld.GetEvents() {
329280
switch ev := e.GetEvent().(type) {
330281
case *pipepb.TestStreamPayload_Event_ElementEvent:
331282
var elms []engine.TestStreamElement
332283
for _, e := range ev.ElementEvent.GetElements() {
333-
elms = append(elms, engine.TestStreamElement{Encoded: mayLP(e.GetEncodedElement()), EventTime: mtime.FromMilliseconds(e.GetTimestamp())})
284+
// Encoded bytes are already handled in handleTestStream if needed.
285+
elms = append(elms, engine.TestStreamElement{Encoded: e.GetEncodedElement(), EventTime: mtime.FromMilliseconds(e.GetTimestamp())})
334286
}
335287
tsb.AddElementEvent(ev.ElementEvent.GetTag(), elms)
336288
case *pipepb.TestStreamPayload_Event_WatermarkEvent:

sdks/go/pkg/beam/runners/prism/internal/handlerunner.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"bytes"
2020
"fmt"
2121
"io"
22+
"log/slog"
2223
"reflect"
2324
"sort"
2425
"strings"
@@ -72,6 +73,7 @@ func (*runner) PrepareUrns() []string {
7273
urns.TransformRedistributeArbitrarily,
7374
urns.TransformRedistributeByKey,
7475
urns.TransformFlatten,
76+
urns.TransformTestStream,
7577
}
7678
}
7779

@@ -82,6 +84,8 @@ func (h *runner) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipep
8284
return h.handleFlatten(tid, t, comps)
8385
case urns.TransformReshuffle, urns.TransformRedistributeArbitrarily, urns.TransformRedistributeByKey:
8486
return h.handleReshuffle(tid, t, comps)
87+
case urns.TransformTestStream:
88+
return h.handleTestStream(tid, t, comps)
8589
default:
8690
panic("unknown urn to Prepare: " + t.GetSpec().GetUrn())
8791
}
@@ -216,6 +220,111 @@ func (h *runner) handleReshuffle(tid string, t *pipepb.PTransform, comps *pipepb
216220
}
217221
}
218222

223+
func (h *runner) handleTestStream(tid string, t *pipepb.PTransform, comps *pipepb.Components) prepareResult {
224+
var pyld pipepb.TestStreamPayload
225+
if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pyld); err != nil {
226+
panic("Failed to decode TestStreamPayload: " + err.Error())
227+
}
228+
coders := map[string]*pipepb.Coder{}
229+
// Ensure awareness of the coder used for the teststream.
230+
cID, err := lpUnknownCoders(pyld.GetCoderId(), coders, comps.GetCoders())
231+
if err != nil {
232+
panic(err)
233+
}
234+
235+
// If the TestStream coder needs to be LP'ed or if it is a coder that has different
236+
// behaviors between nested context and outer context (in Java SDK), then we must
237+
// LP this coder and the TestStream data elements.
238+
forceLP := (cID != pyld.GetCoderId() && coders[pyld.GetCoderId()].GetSpec().GetUrn() != "beam:go:coder:custom:v1") ||
239+
coders[cID].GetSpec().GetUrn() == urns.CoderStringUTF8 ||
240+
coders[cID].GetSpec().GetUrn() == urns.CoderBytes ||
241+
coders[cID].GetSpec().GetUrn() == urns.CoderKV
242+
243+
if !forceLP {
244+
return prepareResult{SubbedComps: &pipepb.Components{
245+
Transforms: map[string]*pipepb.PTransform{tid: t},
246+
}}
247+
}
248+
249+
// The coder needed length prefixing. For simplicity, add a length prefix to each
250+
// encoded element, since we will be sending a length prefixed coder to consume
251+
// this anyway. This is simpler than trying to find all the re-written coders after the fact.
252+
// This also adds a LP-coder for the original coder in comps.
253+
cID, err = forceLpCoder(pyld.GetCoderId(), coders, comps.GetCoders())
254+
if err != nil {
255+
panic(err)
256+
}
257+
slog.Debug("teststream: add coder", "coderId", cID)
258+
259+
mustLP := func(v []byte) []byte {
260+
var buf bytes.Buffer
261+
if err := coder.EncodeVarInt((int64)(len(v)), &buf); err != nil {
262+
panic(err)
263+
}
264+
if _, err := buf.Write(v); err != nil {
265+
panic(err)
266+
}
267+
return buf.Bytes()
268+
}
269+
270+
// We need to loop over the events.
271+
// For element events, we need to apply the mayLP function to the encoded element.
272+
// Then we construct a new payload with the modified events.
273+
var newEvents []*pipepb.TestStreamPayload_Event
274+
for _, event := range pyld.GetEvents() {
275+
switch event.GetEvent().(type) {
276+
case *pipepb.TestStreamPayload_Event_ElementEvent:
277+
elms := event.GetElementEvent().GetElements()
278+
var newElms []*pipepb.TestStreamPayload_TimestampedElement
279+
for _, elm := range elms {
280+
newElm := proto.Clone(elm).(*pipepb.TestStreamPayload_TimestampedElement)
281+
newElm.EncodedElement = mustLP(elm.GetEncodedElement())
282+
slog.Debug("handleTestStream: rewrite bytes",
283+
"before:", string(elm.GetEncodedElement()),
284+
"after:", string(newElm.GetEncodedElement()))
285+
newElms = append(newElms, newElm)
286+
}
287+
newEvents = append(newEvents, &pipepb.TestStreamPayload_Event{
288+
Event: &pipepb.TestStreamPayload_Event_ElementEvent{
289+
ElementEvent: &pipepb.TestStreamPayload_Event_AddElements{
290+
Elements: newElms,
291+
},
292+
},
293+
})
294+
default:
295+
newEvents = append(newEvents, event)
296+
}
297+
}
298+
newPyld := &pipepb.TestStreamPayload{
299+
CoderId: cID,
300+
Events: newEvents,
301+
Endpoint: pyld.GetEndpoint(),
302+
}
303+
b, err := proto.Marshal(newPyld)
304+
if err != nil {
305+
panic(fmt.Sprintf("couldn't marshal new test stream payload: %v", err))
306+
}
307+
308+
ts := proto.Clone(t).(*pipepb.PTransform)
309+
ts.GetSpec().Payload = b
310+
311+
pcolSubs := map[string]*pipepb.PCollection{}
312+
for _, gi := range ts.GetOutputs() {
313+
pcol := comps.GetPcollections()[gi]
314+
newPcol := proto.Clone(pcol).(*pipepb.PCollection)
315+
newPcol.CoderId = cID
316+
slog.Debug("handleTestStream: rewrite coder for output pcoll", "colId", gi, "oldId", pcol.CoderId, "newId", newPcol.CoderId)
317+
pcolSubs[gi] = newPcol
318+
}
319+
320+
tSubs := map[string]*pipepb.PTransform{tid: ts}
321+
return prepareResult{SubbedComps: &pipepb.Components{
322+
Transforms: tSubs,
323+
Pcollections: pcolSubs,
324+
Coders: coders,
325+
}}
326+
}
327+
219328
var _ transformExecuter = (*runner)(nil)
220329

221330
func (*runner) ExecuteUrns() []string {

0 commit comments

Comments
 (0)