@@ -90,10 +90,17 @@ def vk_out(self):
9090class  ComputeGraphGen :
9191    backend_key  =  None 
9292
93-     def  __init__ (self , op_reg_name : str , f : NativeFunction , suite_def : TestSuite ):
93+     def  __init__ (
94+         self ,
95+         op_reg_name : str ,
96+         f : NativeFunction ,
97+         suite_def : TestSuite ,
98+         include_io : bool  =  True ,
99+     ):
94100        self .op_reg_name  =  op_reg_name 
95101        self .f  =  f 
96102        self .suite_def  =  suite_def 
103+         self .include_io  =  include_io 
97104
98105        self .f_sig  =  CppSignatureGroup .from_native_function (
99106            self .f , method = False , fallback_binding = self .f .manual_cpp_binding 
@@ -275,6 +282,10 @@ def create_value_for(  # noqa: C901
275282        prepack  =  self .prepack_ref (ref )
276283        ref_is_view  =  self .suite_def .is_view_op  and  ref .is_out 
277284
285+         # If skipping IO, force is_in to be False 
286+         if  not  self .include_io  and  ref .is_in :
287+             ref .is_in  =  False 
288+ 
278289        cpp_type  =  "IOValueRef"  if  (ref .is_in  and  not  prepack ) else  "ValueRef" 
279290        if  not  include_declarations :
280291            cpp_type  =  "" 
@@ -602,7 +613,8 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str:
602613        graph_build  +=  self .create_value_for (self .refs ["out" ], include_declarations )
603614        graph_build  +=  self .create_op_call ()
604615
605-         graph_build  +=  self .set_output (self .refs ["out" ], include_declarations )
616+         if  self .include_io :
617+             graph_build  +=  self .set_output (self .refs ["out" ], include_declarations )
606618
607619        graph_build  +=  f"{ self .graph } { self .dot } \n " 
608620        graph_build  +=  f"{ self .graph } { self .dot } \n " 
@@ -614,18 +626,22 @@ def gen_graph_build_code(self, include_declarations: bool = True) -> str:
614626
615627    def  gen_graph_exec_code (self , check_output = True ) ->  str :
616628        graph_exec  =  "" 
617-         for  aten_arg  in  self .args :
618-             ref  =  self .refs [aten_arg .name ]
619-             if  ref .is_in :
620-                 graph_exec  +=  self .virtual_resize (ref )
621-                 graph_exec  +=  self .copy_into_staging (ref )
629+         if  self .include_io :
630+             for  aten_arg  in  self .args :
631+                 ref  =  self .refs [aten_arg .name ]
632+                 if  ref .is_in :
633+                     graph_exec  +=  self .virtual_resize (ref )
634+                     graph_exec  +=  self .copy_into_staging (ref )
635+ 
636+             graph_exec  +=  f"{ self .graph } { self .dot } \n " 
622637
623-         graph_exec  +=  f"{ self .graph } { self .dot } \n " 
624638        graph_exec  +=  f"{ self .graph } { self .dot } \n " 
625639
626640        graph_exec  +=  self .declare_vk_out_for (self .refs ["out" ])
627-         graph_exec  +=  self .copy_from_staging (self .refs ["out" ])
628-         if  check_output :
641+         if  self .include_io :
642+             graph_exec  +=  self .copy_from_staging (self .refs ["out" ])
643+ 
644+         if  self .include_io  and  check_output :
629645            graph_exec  +=  self .check_graph_out (self .refs ["out" ])
630646
631647        graph_exec  =  re .sub (r"^" , "  " , graph_exec , flags = re .M )
0 commit comments