@@ -343,6 +343,60 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
343343 expected = {((1 , 2 , 3 , 4 , 5 , 6 ), 300 ): ((2 , 3 , 4 , 5 , 6 , 7 ), 350 )}
344344 self .assertEqual (actual , expected )
345345
346+ def test_map_runtime_aot_intermediate_outputs_delegated (self ):
347+ # Currently, runtime_intermediate_output logs all delegate call arguments
348+ # Test that the map function correctly extracted out the delegated outputs
349+ aot_intermediate_outputs = {
350+ (1 , 2 ): torch .tensor ([4 , 5 ]),
351+ (3 , 4 ): torch .tensor ([10 , 11 , 12 ]),
352+ (5 , 6 ): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
353+ }
354+ runtime_intermediate_outputs = {
355+ (1 , 2 ): [torch .tensor ([1 , 2 , 3 ]), torch .tensor ([4 , 5 ])],
356+ (3 , 4 ): [
357+ torch .tensor ([6 , 7 , 8 , 9 ]),
358+ torch .tensor (1 ),
359+ torch .tensor ([10 , 11 , 12 ]),
360+ ],
361+ (5 , 6 ): [
362+ torch .tensor ([1 ]),
363+ torch .tensor ([2 ]),
364+ torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
365+ ],
366+ }
367+ actual = map_runtime_aot_intermediate_outputs (
368+ aot_intermediate_outputs , runtime_intermediate_outputs
369+ )
370+ expected = {
371+ ((1 , 2 ), torch .tensor ([4 , 5 ])): ((1 , 2 ), torch .tensor ([4 , 5 ])),
372+ ((3 , 4 ), torch .tensor ([10 , 11 , 12 ])): ((3 , 4 ), torch .tensor ([10 , 11 , 12 ])),
373+ ((5 , 6 ), torch .tensor ([13 , 14 , 15 , 16 , 17 ])): (
374+ (5 , 6 ),
375+ torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
376+ ),
377+ }
378+ self .assertEqual (len (actual ), len (expected ))
379+
380+ for (exp_aot_key , exp_aot_value ), (
381+ exp_runtime_key ,
382+ exp_runtime_value ,
383+ ) in expected .items ():
384+ found = False
385+ for (act_aot_key , act_aot_value ), (
386+ act_runtime_key ,
387+ act_runtime_value ,
388+ ) in actual .items ():
389+ if exp_aot_key == act_aot_key and torch .allclose (
390+ exp_aot_value , act_aot_value
391+ ):
392+ found = True
393+ self .assertEqual (exp_runtime_key , act_runtime_key )
394+ self .assertTrue (
395+ torch .allclose (exp_runtime_value , act_runtime_value )
396+ )
397+ break
398+ self .assertTrue (found )
399+
346400 def test_convert_input_to_tensor_convertible_inputs (self ):
347401 # Scalar -> tensor
348402 actual_output1 = convert_to_float_tensor (5 )
0 commit comments