Skip to content

Commit ed39be1

Browse files
committed
simplify allocation: context manager returns input memrefs
1 parent 1cde5a3 commit ed39be1

File tree

2 files changed

+6
-13
lines changed

2 files changed

+6
-13
lines changed

python/examples/xegpu_matmul/matmul.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,6 @@ def _allocate_array(
9292
self.gpu_memrefs[key] = mref
9393
return mref
9494

95-
def _allocate_inputs(self, execution_engine: ExecutionEngine):
96-
self._allocate_array("A", self.a_shape, self.ab_type, execution_engine)
97-
self._allocate_array("B", self.b_shape, self.ab_type, execution_engine)
98-
self._allocate_array("C", self.c_shape, self.c_type, execution_engine)
99-
10095
def _deallocate_all(self, execution_engine: ExecutionEngine):
10196
for (_, dtype_str), mref in self.gpu_memrefs.items():
10297
dealloc_func = execution_engine.lookup("gpu_dealloc_" + dtype_str)
@@ -105,10 +100,10 @@ def _deallocate_all(self, execution_engine: ExecutionEngine):
105100
self.gpu_memrefs = {}
106101

107102
@contextmanager
108-
def allocate(self, execution_engine: ExecutionEngine):
103+
def allocate_inputs(self, execution_engine: ExecutionEngine):
109104
try:
110-
self._allocate_inputs(execution_engine)
111-
yield None
105+
inputs = self._get_input_arrays(execution_engine)
106+
yield inputs
112107
finally:
113108
self._deallocate_all(execution_engine)
114109

@@ -141,7 +136,7 @@ def _reference_solution(self) -> np.ndarray:
141136
raise NotImplementedError("Bias verification not implemented")
142137
return C_ref
143138

144-
def get_input_arrays(
139+
def _get_input_arrays(
145140
self, execution_engine: ExecutionEngine
146141
) -> list[ctypes.Structure]:
147142
A_gpu = self._allocate_array("A", self.a_shape, self.ab_type, execution_engine)

python/examples/xegpu_matmul/runner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ def execute(
7373
# get execution engine
7474
engine = get_engine(payload_module, requirements=workload.requirements())
7575

76-
with workload.allocate(execution_engine=engine):
76+
with workload.allocate_inputs(execution_engine=engine) as inputs:
7777
# prepare function arguments
78-
inputs = workload.get_input_arrays(execution_engine=engine)
7978
pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs]
8079
packed_args = get_packed_arg(pointers)
8180

@@ -150,8 +149,7 @@ def benchmark(*args):
150149
# get execution engine, rtclock requires mlir_c_runner
151150
engine = get_engine(payload_module)
152151

153-
with workload.allocate(execution_engine=engine):
154-
inputs = workload.get_input_arrays(execution_engine=engine)
152+
with workload.allocate_inputs(execution_engine=engine) as inputs:
155153
pointers = [ctypes.pointer(ctypes.pointer(m)) for m in inputs]
156154
if check_correctness:
157155
# call payload once to verify correctness

0 commit comments

Comments
 (0)