Skip to content

Commit 549c81b

Browse files
committed
add snake_case python binding wrappers
1 parent 0db7e67 commit 549c81b

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ def __init__(
4141
)
4242

4343

44+
def get_desc_op(
45+
target: Value,
46+
*,
47+
loc=None,
48+
ip=None,
49+
) -> GetDescOp:
50+
return GetDescOp(target, loc=loc, ip=ip)
51+
52+
4453
@_ods_cext.register_operation(_Dialect, replace=True)
4554
class SetDescLayoutOp(SetDescLayoutOp):
4655
"""Specialization for SetDescLayoutOp class."""
@@ -87,6 +96,25 @@ def __init__(
8796
)
8897

8998

99+
def set_desc_layout(
100+
target: Union[Operation, Value],
101+
sg_layout: MixedValues,
102+
sg_data: MixedValues,
103+
*,
104+
inst_data: Optional[MixedValues] = None,
105+
loc=None,
106+
ip=None,
107+
) -> SetDescLayoutOp:
108+
return SetDescLayoutOp(
109+
target,
110+
sg_layout,
111+
sg_data,
112+
inst_data=inst_data,
113+
loc=loc,
114+
ip=ip,
115+
)
116+
117+
90118
@_ods_cext.register_operation(_Dialect, replace=True)
91119
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
92120
"""Specialization for SetOpLayoutAttrOp class."""
@@ -134,6 +162,29 @@ def __init__(
134162
)
135163

136164

165+
def set_op_layout_attr(
166+
target: Union[Operation, Value],
167+
sg_layout: MixedValues,
168+
sg_data: MixedValues,
169+
*,
170+
inst_data: Optional[MixedValues] = None,
171+
index: Optional[Union[int, Attribute]] = None,
172+
result: Optional[Union[bool, Attribute]] = None,
173+
loc=None,
174+
ip=None,
175+
) -> SetOpLayoutAttrOp:
176+
return SetOpLayoutAttrOp(
177+
target,
178+
sg_layout,
179+
sg_data,
180+
inst_data=inst_data,
181+
index=index,
182+
result=result,
183+
loc=loc,
184+
ip=ip,
185+
)
186+
187+
137188
@_ods_cext.register_operation(_Dialect, replace=True)
138189
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
139190
"""Specialization for SetGPULaunchThreadsOp class."""
@@ -212,3 +263,22 @@ def __init__(
212263
loc=loc,
213264
ip=ip,
214265
)
266+
267+
268+
def convert_layout(
269+
target: Value,
270+
sg_layout: MixedValues,
271+
sg_data: MixedValues,
272+
*,
273+
inst_data: Optional[MixedValues] = None,
274+
loc=None,
275+
ip=None,
276+
) -> ConvertLayoutOp:
277+
return ConvertLayoutOp(
278+
target,
279+
sg_layout,
280+
sg_data,
281+
inst_data=inst_data,
282+
loc=loc,
283+
ip=ip,
284+
)

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def getDescOpDefaultIndex():
2525
)
2626
with InsertionPoint(sequence.body):
2727
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
28-
desc_handle = xegpu.GetDescOp(operand)
28+
desc_handle = xegpu.get_desc_op(operand)
2929
transform.YieldOp()
3030
# CHECK-LABEL: TEST: getDescOpDefaultIndex
3131
# CHECK: transform.xegpu.get_desc_op %
@@ -39,7 +39,7 @@ def setDescLayoutMinimal():
3939
transform.OperationType.get("xegpu.create_nd_tdesc"),
4040
)
4141
with InsertionPoint(sequence.body):
42-
xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
42+
xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
4343
transform.YieldOp()
4444
# CHECK-LABEL: TEST: setDescLayoutMinimal
4545
# CHECK: %0 = transform.xegpu.set_desc_layout %
@@ -55,7 +55,7 @@ def setDescLayoutInstData():
5555
transform.OperationType.get("xegpu.create_nd_tdesc"),
5656
)
5757
with InsertionPoint(sequence.body):
58-
xegpu.SetDescLayoutOp(
58+
xegpu.set_desc_layout(
5959
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
6060
)
6161
transform.YieldOp()
@@ -74,7 +74,7 @@ def setOpLayoutAttrOperandMinimal():
7474
transform.OperationType.get("xegpu.dpas"),
7575
)
7676
with InsertionPoint(sequence.body):
77-
xegpu.SetOpLayoutAttrOp(
77+
xegpu.set_op_layout_attr(
7878
sequence.bodyTarget,
7979
sg_layout=[6, 4],
8080
sg_data=[32, 16],
@@ -97,7 +97,7 @@ def setOpLayoutAttrResult():
9797
transform.OperationType.get("xegpu.dpas"),
9898
)
9999
with InsertionPoint(sequence.body):
100-
xegpu.SetOpLayoutAttrOp(
100+
xegpu.set_op_layout_attr(
101101
sequence.bodyTarget,
102102
index=0,
103103
sg_layout=[6, 4],
@@ -139,7 +139,7 @@ def ConvertLayoutMinimal():
139139
)
140140
with InsertionPoint(sequence.body):
141141
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
142-
xegpu.ConvertLayoutOp(
142+
xegpu.convert_layout(
143143
operand,
144144
sg_layout=[6, 4],
145145
sg_data=[32, 16],
@@ -160,7 +160,7 @@ def ConvertLayout():
160160
)
161161
with InsertionPoint(sequence.body):
162162
operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
163-
xegpu.ConvertLayoutOp(
163+
xegpu.convert_layout(
164164
operand,
165165
sg_layout=[6, 4],
166166
sg_data=[32, 16],

0 commit comments

Comments
 (0)