@@ -340,6 +340,11 @@ def call_function(
340
340
elif target == torch .ops .higher_order .cond :
341
341
pred , true_fn , false_fn , inputs = args
342
342
return self .callback .call_cond (pred , true_fn , false_fn , inputs , meta )
343
+ elif target == torch .ops .higher_order .while_loop :
344
+ cond , body , carried_inputs , additional_inputs = args
345
+ return self .callback .call_while (
346
+ cond , body , carried_inputs , additional_inputs , meta
347
+ )
343
348
elif target == torch .ops .higher_order .map_impl :
344
349
f , mapped_args , operands = args # type: ignore[assignment]
345
350
return self .callback .call_map (f , mapped_args , operands , meta )
@@ -497,6 +502,31 @@ def call_cond(
497
502
meta ,
498
503
)
499
504
505
+ def call_while (
506
+ self ,
507
+ cond_fn : torch .fx .GraphModule ,
508
+ body_fn : torch .fx .GraphModule ,
509
+ carried_inputs : List [Argument ],
510
+ additional_inputs : List [Argument ],
511
+ meta : NodeMetadata ,
512
+ ) -> ProxyValue :
513
+ cond_fn = self .call_submodule (cond_fn , (* carried_inputs , * additional_inputs ))
514
+ body_fn = self .call_submodule (body_fn , (* carried_inputs , * additional_inputs ))
515
+ assert cond_fn is not None
516
+ assert body_fn is not None
517
+ return self ._fx (
518
+ "call_function" ,
519
+ torch .ops .higher_order .while_loop ,
520
+ (
521
+ cond_fn .graph_module ,
522
+ body_fn .graph_module ,
523
+ carried_inputs ,
524
+ additional_inputs ,
525
+ ),
526
+ {},
527
+ meta ,
528
+ )
529
+
500
530
def call_map (
501
531
self ,
502
532
f : torch .fx .GraphModule ,
0 commit comments