Skip to content

Commit 4e8767c

Browse files
committed
add snake_case python binding wrappers
1 parent 5a5de34 commit 4e8767c

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ def __init__(
4242
)
4343

4444

45+
def get_desc_op(
46+
target: Value,
47+
*,
48+
loc=None,
49+
ip=None,
50+
) -> GetDescOp:
51+
return GetDescOp(target, loc=loc, ip=ip)
52+
53+
4554
@_ods_cext.register_operation(_Dialect, replace=True)
4655
class SetDescLayoutOp(SetDescLayoutOp):
4756
"""Specialization for SetDescLayoutOp class."""
@@ -88,6 +97,25 @@ def __init__(
8897
)
8998

9099

100+
def set_desc_layout(
101+
target: Union[Operation, Value],
102+
sg_layout: MixedValues,
103+
sg_data: MixedValues,
104+
*,
105+
inst_data: Optional[MixedValues] = None,
106+
loc=None,
107+
ip=None,
108+
) -> SetDescLayoutOp:
109+
return SetDescLayoutOp(
110+
target,
111+
sg_layout,
112+
sg_data,
113+
inst_data=inst_data,
114+
loc=loc,
115+
ip=ip,
116+
)
117+
118+
91119
@_ods_cext.register_operation(_Dialect, replace=True)
92120
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
93121
"""Specialization for SetOpLayoutAttrOp class."""
@@ -135,6 +163,29 @@ def __init__(
135163
)
136164

137165

166+
def set_op_layout_attr(
167+
target: Union[Operation, Value],
168+
sg_layout: MixedValues,
169+
sg_data: MixedValues,
170+
*,
171+
inst_data: Optional[MixedValues] = None,
172+
index: Optional[Union[int, Attribute]] = None,
173+
result: Optional[Union[bool, Attribute]] = None,
174+
loc=None,
175+
ip=None,
176+
) -> SetOpLayoutAttrOp:
177+
return SetOpLayoutAttrOp(
178+
target,
179+
sg_layout,
180+
sg_data,
181+
inst_data=inst_data,
182+
index=index,
183+
result=result,
184+
loc=loc,
185+
ip=ip,
186+
)
187+
188+
138189
@_ods_cext.register_operation(_Dialect, replace=True)
139190
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
140191
"""Specialization for SetGPULaunchThreadsOp class."""
@@ -254,3 +305,22 @@ def __init__(
254305
loc=loc,
255306
ip=ip,
256307
)
308+
309+
310+
def convert_layout(
311+
target: Value,
312+
sg_layout: MixedValues,
313+
sg_data: MixedValues,
314+
*,
315+
inst_data: Optional[MixedValues] = None,
316+
loc=None,
317+
ip=None,
318+
) -> ConvertLayoutOp:
319+
return ConvertLayoutOp(
320+
target,
321+
sg_layout,
322+
sg_data,
323+
inst_data=inst_data,
324+
loc=loc,
325+
ip=ip,
326+
)

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def getDescOpDefaultIndex():
2525
)
2626
with InsertionPoint(sequence.body):
2727
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
28-
desc_handle = xegpu.GetDescOp(operand)
28+
desc_handle = xegpu.get_desc_op(operand)
2929
transform.YieldOp()
3030
# CHECK-LABEL: TEST: getDescOpDefaultIndex
3131
# CHECK: transform.xegpu.get_desc_op %
@@ -39,7 +39,7 @@ def setDescLayoutMinimal():
3939
transform.OperationType.get("xegpu.create_nd_tdesc"),
4040
)
4141
with InsertionPoint(sequence.body):
42-
xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
42+
xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
4343
transform.YieldOp()
4444
# CHECK-LABEL: TEST: setDescLayoutMinimal
4545
# CHECK: %0 = transform.xegpu.set_desc_layout %
@@ -55,7 +55,7 @@ def setDescLayoutInstData():
5555
transform.OperationType.get("xegpu.create_nd_tdesc"),
5656
)
5757
with InsertionPoint(sequence.body):
58-
xegpu.SetDescLayoutOp(
58+
xegpu.set_desc_layout(
5959
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
6060
)
6161
transform.YieldOp()
@@ -74,7 +74,7 @@ def setOpLayoutAttrOperandMinimal():
7474
transform.OperationType.get("xegpu.dpas"),
7575
)
7676
with InsertionPoint(sequence.body):
77-
xegpu.SetOpLayoutAttrOp(
77+
xegpu.set_op_layout_attr(
7878
sequence.bodyTarget,
7979
sg_layout=[6, 4],
8080
sg_data=[32, 16],
@@ -97,7 +97,7 @@ def setOpLayoutAttrResult():
9797
transform.OperationType.get("xegpu.dpas"),
9898
)
9999
with InsertionPoint(sequence.body):
100-
xegpu.SetOpLayoutAttrOp(
100+
xegpu.set_op_layout_attr(
101101
sequence.bodyTarget,
102102
index=0,
103103
sg_layout=[6, 4],
@@ -204,7 +204,7 @@ def ConvertLayoutMinimal():
204204
)
205205
with InsertionPoint(sequence.body):
206206
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
207-
xegpu.ConvertLayoutOp(
207+
xegpu.convert_layout(
208208
operand,
209209
sg_layout=[6, 4],
210210
sg_data=[32, 16],
@@ -225,7 +225,7 @@ def ConvertLayout():
225225
)
226226
with InsertionPoint(sequence.body):
227227
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
228-
xegpu.ConvertLayoutOp(
228+
xegpu.convert_layout(
229229
operand,
230230
sg_layout=[6, 4],
231231
sg_data=[32, 16],

0 commit comments

Comments
 (0)