@@ -92,11 +92,6 @@ def _allocate_array(
9292 self .gpu_memrefs [key ] = mref
9393 return mref
9494
95- def _allocate_inputs (self , execution_engine : ExecutionEngine ):
96- self ._allocate_array ("A" , self .a_shape , self .ab_type , execution_engine )
97- self ._allocate_array ("B" , self .b_shape , self .ab_type , execution_engine )
98- self ._allocate_array ("C" , self .c_shape , self .c_type , execution_engine )
99-
10095 def _deallocate_all (self , execution_engine : ExecutionEngine ):
10196 for (_ , dtype_str ), mref in self .gpu_memrefs .items ():
10297 dealloc_func = execution_engine .lookup ("gpu_dealloc_" + dtype_str )
@@ -105,10 +100,10 @@ def _deallocate_all(self, execution_engine: ExecutionEngine):
105100 self .gpu_memrefs = {}
106101
107102 @contextmanager
108- def allocate (self , execution_engine : ExecutionEngine ):
103+ def allocate_inputs (self , execution_engine : ExecutionEngine ):
109104 try :
110- self ._allocate_inputs (execution_engine )
111- yield None
105+ inputs = self ._get_input_arrays (execution_engine )
106+ yield inputs
112107 finally :
113108 self ._deallocate_all (execution_engine )
114109
@@ -141,7 +136,7 @@ def _reference_solution(self) -> np.ndarray:
141136 raise NotImplementedError ("Bias verification not implemented" )
142137 return C_ref
143138
144- def get_input_arrays (
139+ def _get_input_arrays (
145140 self , execution_engine : ExecutionEngine
146141 ) -> list [ctypes .Structure ]:
147142 A_gpu = self ._allocate_array ("A" , self .a_shape , self .ab_type , execution_engine )
0 commit comments