@@ -227,44 +227,101 @@ func (h *runner) handleTestStream(tid string, t *pipepb.PTransform, comps *pipep
227227 }
228228 coders := map [string ]* pipepb.Coder {}
229229 // Ensure awareness of the coder used for the teststream.
230- cID , err := lpUnknownCoders (pyld .GetCoderId (), coders , comps .GetCoders ())
230+ ocID := pyld .GetCoderId ()
231+ cID , err := lpUnknownCoders (ocID , coders , comps .GetCoders ())
231232 if err != nil {
232233 panic (err )
233234 }
234235
235236 // If the TestStream coder needs to be LP'ed or if it is a coder that has different
236237 // behaviors between nested context and outer context (in Java SDK), then we must
237238 // LP this coder and the TestStream data elements.
238- forceLP := (cID != pyld . GetCoderId () && coders [pyld . GetCoderId () ].GetSpec ().GetUrn () != "beam:go:coder:custom:v1" ) ||
239- coders [cID ].GetSpec ().GetUrn () == urns .CoderStringUTF8 ||
240- coders [cID ].GetSpec ().GetUrn () == urns .CoderBytes ||
241- coders [cID ].GetSpec ().GetUrn () == urns .CoderKV
239+ forceLP := (cID != ocID && coders [ocID ].GetSpec ().GetUrn () != "beam:go:coder:custom:v1" ) ||
240+ coders [ocID ].GetSpec ().GetUrn () == urns .CoderStringUTF8 ||
241+ coders [ocID ].GetSpec ().GetUrn () == urns .CoderBytes ||
242+ coders [ocID ].GetSpec ().GetUrn () == urns .CoderKV
242243
243244 if ! forceLP {
244245 return prepareResult {SubbedComps : & pipepb.Components {
245246 Transforms : map [string ]* pipepb.PTransform {tid : t },
246247 }}
247248 }
248249
249- // The coder needed length prefixing. For simplicity, add a length prefix to each
250- // encoded element, since we will be sending a length prefixed coder to consume
251- // this anyway. This is simpler than trying to find all the re-written coders after the fact.
252- // This also adds a LP-coder for the original coder in comps.
253- cID , err = forceLpCoder (pyld .GetCoderId (), coders , comps .GetCoders ())
254- if err != nil {
255- panic (err )
256- }
257- slog .Debug ("teststream: add coder" , "coderId" , cID )
258-
259- mustLP := func (v []byte ) []byte {
260- var buf bytes.Buffer
261- if err := coder .EncodeVarInt ((int64 )(len (v )), & buf ); err != nil {
250+ var mustLP func (v []byte ) []byte
251+ if coders [ocID ].GetSpec ().GetUrn () != urns .CoderKV {
252+ // The coder needed length prefixing. For simplicity, add a length prefix to each
253+ // encoded element, since we will be sending a length prefixed coder to consume
254+ // this anyway. This is simpler than trying to find all the re-written coders after the fact.
255+ // This also adds a LP-coder for the original coder in comps.
256+ cID , err = forceLpCoder (pyld .GetCoderId (), coders , comps .GetCoders ())
257+ if err != nil {
262258 panic (err )
263259 }
264- if _ , err := buf .Write (v ); err != nil {
260+ slog .Debug ("teststream: add coder" , "coderId" , cID )
261+
262+ mustLP = func (v []byte ) []byte {
263+ var buf bytes.Buffer
264+ if err := coder .EncodeVarInt ((int64 )(len (v )), & buf ); err != nil {
265+ panic (err )
266+ }
267+ if _ , err := buf .Write (v ); err != nil {
268+ panic (err )
269+ }
270+ return buf .Bytes ()
271+ }
272+ } else {
273+ // For a KV coder, we only length-prefix the value coder because we need to
274+ // preserve the original structure of the key coder. This allows the key
275+ // coder to be easily extracted later to retrieve the KeyBytes from the
276+ // encoded elements.
277+
278+ c := coders [ocID ]
279+ kcid := c .GetComponentCoderIds ()[0 ]
280+ vcid := c .GetComponentCoderIds ()[1 ]
281+
282+ var lpvcid string
283+ lpvcid , err = forceLpCoder (vcid , coders , comps .GetCoders ())
284+ if err != nil {
265285 panic (err )
266286 }
267- return buf .Bytes ()
287+
288+ slog .Debug ("teststream: add coder" , "coderId" , lpvcid )
289+
290+ kvc := & pipepb.Coder {
291+ Spec : & pipepb.FunctionSpec {
292+ Urn : urns .CoderKV ,
293+ },
294+ ComponentCoderIds : []string {kcid , lpvcid },
295+ }
296+
297+ kvcID := ocID + "_vlp"
298+ coders [kvcID ] = kvc
299+
300+ slog .Debug ("teststream: add coder" , "coderId" , kvcID )
301+
302+ cID = kvcID
303+
304+ kd := collectionPullDecoder (kcid , coders , comps )
305+ mustLP = func (v []byte ) []byte {
306+ elmBuf := bytes .NewBuffer (v )
307+ keyBytes := kd (elmBuf )
308+
309+ var buf bytes.Buffer
310+ if _ , err := buf .Write (keyBytes ); err != nil {
311+ panic (err )
312+ }
313+
314+ // write the varint-encoded length of the value (element length minus key length)
315+ if err := coder .EncodeVarInt ((int64 )(len (v )- len (keyBytes )), & buf ); err != nil {
316+ panic (err )
317+ }
318+
319+ // write the value, i.e. the remaining bytes from the buffer
320+ if _ , err := buf .Write (elmBuf .Bytes ()); err != nil {
321+ panic (err )
322+ }
323+ return buf .Bytes ()
324+ }
268325 }
269326
270327 // We need to loop over the events.
0 commit comments