@@ -343,6 +343,60 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
343
343
expected = {((1 , 2 , 3 , 4 , 5 , 6 ), 300 ): ((2 , 3 , 4 , 5 , 6 , 7 ), 350 )}
344
344
self .assertEqual (actual , expected )
345
345
346
+ def test_map_runtime_aot_intermediate_outputs_delegated (self ):
347
+ # Currently, runtime_intermediate_output logs all delegate call arguments
348
+ # Test that the map function correctly extracted out the delegated outputs
349
+ aot_intermediate_outputs = {
350
+ (1 , 2 ): torch .tensor ([4 , 5 ]),
351
+ (3 , 4 ): torch .tensor ([10 , 11 , 12 ]),
352
+ (5 , 6 ): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
353
+ }
354
+ runtime_intermediate_outputs = {
355
+ (1 , 2 ): [torch .tensor ([1 , 2 , 3 ]), torch .tensor ([4 , 5 ])],
356
+ (3 , 4 ): [
357
+ torch .tensor ([6 , 7 , 8 , 9 ]),
358
+ torch .tensor (1 ),
359
+ torch .tensor ([10 , 11 , 12 ]),
360
+ ],
361
+ (5 , 6 ): [
362
+ torch .tensor ([1 ]),
363
+ torch .tensor ([2 ]),
364
+ torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
365
+ ],
366
+ }
367
+ actual = map_runtime_aot_intermediate_outputs (
368
+ aot_intermediate_outputs , runtime_intermediate_outputs
369
+ )
370
+ expected = {
371
+ ((1 , 2 ), torch .tensor ([4 , 5 ])): ((1 , 2 ), torch .tensor ([4 , 5 ])),
372
+ ((3 , 4 ), torch .tensor ([10 , 11 , 12 ])): ((3 , 4 ), torch .tensor ([10 , 11 , 12 ])),
373
+ ((5 , 6 ), torch .tensor ([13 , 14 , 15 , 16 , 17 ])): (
374
+ (5 , 6 ),
375
+ torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
376
+ ),
377
+ }
378
+ self .assertEqual (len (actual ), len (expected ))
379
+
380
+ for (exp_aot_key , exp_aot_value ), (
381
+ exp_runtime_key ,
382
+ exp_runtime_value ,
383
+ ) in expected .items ():
384
+ found = False
385
+ for (act_aot_key , act_aot_value ), (
386
+ act_runtime_key ,
387
+ act_runtime_value ,
388
+ ) in actual .items ():
389
+ if exp_aot_key == act_aot_key and torch .allclose (
390
+ exp_aot_value , act_aot_value
391
+ ):
392
+ found = True
393
+ self .assertEqual (exp_runtime_key , act_runtime_key )
394
+ self .assertTrue (
395
+ torch .allclose (exp_runtime_value , act_runtime_value )
396
+ )
397
+ break
398
+ self .assertTrue (found )
399
+
346
400
def test_convert_input_to_tensor_convertible_inputs (self ):
347
401
# Scalar -> tensor
348
402
actual_output1 = convert_to_float_tensor (5 )
0 commit comments