@@ -302,67 +302,46 @@ def test_map_runtime_aot_intermediate_outputs_single_element_tuple(self):
302302 }
303303 self .assertEqual (actual , expected )
304304
305- def test_map_runtime_aot_intermediate_outputs_exact_match (self ):
306- # Exact match between aot and runtime debug_handles
307- aot_intermediate_outputs = {(0 , 1 ): 100 , (2 , 3 ): 200 , (4 , 5 ): 300 }
308- runtime_intermediate_outputs = {(0 , 1 ): 150 , (2 , 3 ): 200 , (4 , 5 ): 300 }
309- actual = map_runtime_aot_intermediate_outputs (
310- aot_intermediate_outputs , runtime_intermediate_outputs
311- )
312- expected = {
313- ((0 , 1 ), 100 ): ((0 , 1 ), 150 ),
314- ((2 , 3 ), 200 ): ((2 , 3 ), 200 ),
315- ((4 , 5 ), 300 ): ((4 , 5 ), 300 ),
316- }
317- self .assertEqual (actual , expected )
318-
319305 def test_map_runtime_aot_intermediate_outputs_no_overlaps (self ):
320306 # No overlaps between aot and runtime debug_handles
321- aot_intermediate_outputs = {(0 , 1 ): 100 , (4 , 5 ): 300 }
307+ aot_intermediate_outputs = {(0 ,): 100 , (4 ,): 300 }
322308 runtime_intermediate_outputs = {(2 , 3 ): 200 , (8 , 9 ): 300 }
323309 actual = map_runtime_aot_intermediate_outputs (
324310 aot_intermediate_outputs , runtime_intermediate_outputs
325311 )
326312 expected = {}
327313 self .assertEqual (actual , expected )
328314
329- def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime (self ):
330- # Multiple aot debug_handles map to one runtime debug_handle
331- aot_intermediate_outputs = {(0 , 1 , 2 ): 100 , (3 , 4 ): 300 }
332- runtime_intermediate_outputs = {(1 , 2 , 3 ): 250 , (8 , 9 ): 300 }
333- actual = map_runtime_aot_intermediate_outputs (
334- aot_intermediate_outputs , runtime_intermediate_outputs
335- )
336- expected = {((0 , 1 , 2 , 3 , 4 ), 300 ): ((1 , 2 , 3 ), 250 )}
337- self .assertEqual (actual , expected )
338-
339- def test_map_runtime_aot_intermediate_outputs_one_aot_to_multiple_runtime (self ):
340- # One aot debug_handle map to multiple runtime debug_handles
341- aot_intermediate_outputs = {(0 , 1 , 2 , 3 , 4 ): 100 , (8 , 9 ): 300 }
342- runtime_intermediate_outputs = {(0 , 1 ): 150 , (2 , 3 ): 200 , (4 , 5 ): 300 }
315+ def test_map_runtime_aot_intermediate_outputs_partial_match (self ):
316+ # Partial match between aot and runtime debug_handles will return empty
317+ aot_intermediate_outputs = {(2 ,): 100 , (9 ,): 300 }
318+ runtime_intermediate_outputs = {(2 , 3 ): 200 , (8 , 9 ): 300 }
343319 actual = map_runtime_aot_intermediate_outputs (
344320 aot_intermediate_outputs , runtime_intermediate_outputs
345321 )
346- expected = {(( 0 , 1 , 2 , 3 , 4 ), 100 ): (( 0 , 1 , 2 , 3 , 4 , 5 ), 300 ) }
322+ expected = {}
347323 self .assertEqual (actual , expected )
348324
349- def test_map_runtime_aot_intermediate_outputs_complex_chain (self ):
350- # Complex chain (N-to-N mapping)
351- aot_intermediate_outputs = {(1 , 2 ): 100 , (3 , 4 ): 200 , (5 , 6 ): 300 }
352- runtime_intermediate_outputs = {(2 , 3 ): 150 , ( 4 , 5 ): 250 , (6 , 7 ): 350 }
325+ def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime (self ):
326+ # Multiple aot debug_handles map to one runtime debug_handle
327+ aot_intermediate_outputs = {(0 , ): 100 , (1 , ): 200 , (2 , ): 300 , ( 3 ,): 400 }
328+ runtime_intermediate_outputs = {(2 , 3 , 1 ): 250 , (8 , 9 ): 300 }
353329 actual = map_runtime_aot_intermediate_outputs (
354330 aot_intermediate_outputs , runtime_intermediate_outputs
355331 )
356- expected = {((1 , 2 , 3 , 4 , 5 , 6 ), 300 ): ((2 , 3 , 4 , 5 , 6 , 7 ), 350 )}
332+ expected = {((2 , 3 , 1 ), 200 ): ((2 , 3 , 1 ), 250 )}
357333 self .assertEqual (actual , expected )
358334
359335 def test_map_runtime_aot_intermediate_outputs_delegated (self ):
360336 # Currently, runtime_intermediate_output logs all delegate call arguments
361337 # Test that the map function correctly extracted out the delegated outputs
362338 aot_intermediate_outputs = {
363- (1 , 2 ): torch .tensor ([4 , 5 ]),
364- (3 , 4 ): torch .tensor ([10 , 11 , 12 ]),
365- (5 , 6 ): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
339+ (1 ,): torch .tensor ([4 , 1 ]),
340+ (2 ,): torch .tensor ([4 , 5 ]),
341+ (3 ,): torch .tensor ([10 , 10 , 13 ]),
342+ (4 ,): torch .tensor ([10 , 11 , 12 ]),
343+ (5 ,): torch .tensor ([13 , 14 , 15 , 16 , 21 ]),
344+ (6 ,): torch .tensor ([13 , 14 , 15 , 16 , 17 ]),
366345 }
367346 runtime_intermediate_outputs = {
368347 (1 , 2 ): [torch .tensor ([1 , 2 , 3 ]), torch .tensor ([4 , 5 ])],
0 commit comments