Skip to content

Commit 412ce73

Browse files
bangtianliukeshavvinayak01
authored andcommitted
[Codegen][Tuner] add python binding for VirtualMMAIntrinsic (iree-org#21403)
This PR exposes IREE's `VirtualMMAIntrinsicAttr` and `VirtualMMAAttr` to python. What's more, This PR also expose python binding MMAAttr::getVirtualIntrinsics() for the function `mma.getVirtualIntrinsics()` from IREE C++ end, which returns a list of associated VirtualMMAIntrinsic enums. As a result, the tuner can now enumerate all MMA options including the virtual ones consistent with how they are handled what is in KernelConfig.cpp: https://github.com/iree-org/iree/blob/af56a4793d8e9eb0dca7b7bfc7f41ec50d4d23c1/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp#L1376-L1387 Issue: nod-ai/amd-shark-ai#1769 --------- Signed-off-by: Bangtian Liu <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 277807c commit 412ce73

File tree

8 files changed

+288
-13
lines changed

8 files changed

+288
-13
lines changed

compiler/bindings/c/iree/compiler/dialects/iree_gpu.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ MLIR_CAPI_EXPORTED MlirTypeID ireeGPUMMAAttrGetTypeID(void);
6868
MLIR_CAPI_EXPORTED MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx,
6969
uint32_t value);
7070

71+
MLIR_CAPI_EXPORTED bool
72+
ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(MlirAttribute attr);
73+
74+
MLIR_CAPI_EXPORTED MlirTypeID ireeGPUVirtualMMAIntrinsicAttrGetTypeID(void);
75+
76+
MLIR_CAPI_EXPORTED MlirAttribute
77+
ireeGPUVirtualMMAIntrinsicAttrGet(MlirContext mlirCtx, uint32_t value);
78+
79+
MLIR_CAPI_EXPORTED uint32_t
80+
ireeGPUVirtualMMAIntrinsicAttrGetValue(MlirAttribute attr);
81+
82+
MLIR_CAPI_EXPORTED bool ireeAttributeIsAGPUVirtualMMAAttr(MlirAttribute attr);
83+
84+
MLIR_CAPI_EXPORTED MlirTypeID ireeGPUVirtualMMAAttrGetTypeID(void);
85+
86+
MLIR_CAPI_EXPORTED MlirAttribute ireeGPUVirtualMMAAttrGet(MlirContext mlirCtx,
87+
uint32_t value);
7188
struct ireeGPUMMAInfo {
7289
MlirType aElementType;
7390
MlirType bElementType;
@@ -82,6 +99,9 @@ struct ireeGPUMMAInfo {
8299

83100
MLIR_CAPI_EXPORTED ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr);
84101

102+
MLIR_CAPI_EXPORTED MlirAttribute
103+
ireeGPUMMAAttrGetVirtualMMAIntrinsic(MlirAttribute attr);
104+
85105
MLIR_CAPI_EXPORTED bool
86106
ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr);
87107

compiler/bindings/python/IREECompilerDialectsModule.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,95 @@ NB_MODULE(_ireeCompilerDialects, m) {
357357
return py::make_tuple(info.aVectorType, info.bVectorType,
358358
info.cVectorType);
359359
})
360+
.def_property_readonly(
361+
"mnk_shape",
362+
[](MlirAttribute self) -> py::tuple {
363+
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
364+
return py::make_tuple(info.mElements, info.nElements,
365+
info.kElements);
366+
})
367+
.def(
368+
"get_virtual_intrinsics",
369+
[](MlirAttribute self) {
370+
MlirAttribute rawArrayAttr =
371+
ireeGPUMMAAttrGetVirtualMMAIntrinsic(self);
372+
if (mlirAttributeIsNull(rawArrayAttr)) {
373+
return std::vector<py::object>{};
374+
}
375+
376+
auto arrayAttr = mlir::cast<mlir::ArrayAttr>(unwrap(rawArrayAttr));
377+
static py::object virtualEnum =
378+
py::module_::import_(kGpuModuleImportPath)
379+
.attr("VirtualMMAIntrinsic");
380+
381+
std::vector<py::object> result;
382+
for (mlir::Attribute attr : arrayAttr) {
383+
auto intAttr = mlir::cast<mlir::IntegerAttr>(attr);
384+
result.push_back(
385+
virtualEnum(static_cast<uint32_t>(intAttr.getInt())));
386+
}
387+
388+
return result;
389+
},
390+
"Returns a list of virtual intrinsics associated with this "
391+
"MMAAttr.");
392+
393+
//===-------------------------------------------------------------------===//
394+
// GPUVirtualMMAIntrinsicAttr
395+
//===-------------------------------------------------------------------===//
396+
397+
mlir_attribute_subclass(iree_gpu_module, "VirtualMMAIntrinsicAttr",
398+
ireeAttributeIsAGPUVirtualMMAIntrinsicAttr,
399+
ireeGPUVirtualMMAIntrinsicAttrGetTypeID)
400+
.def_classmethod(
401+
"get",
402+
[](const py::object &, uint32_t value, MlirContext ctx) {
403+
return ireeGPUVirtualMMAIntrinsicAttrGet(ctx, value);
404+
},
405+
"cls"_a, "value"_a, "ctx"_a = py::none(),
406+
"Gets an #iree_gpu.virtual_mma_intrinsic from parameters.")
407+
.def_property_readonly("raw_value",
408+
ireeGPUVirtualMMAIntrinsicAttrGetValue)
409+
.def_property_readonly("value",
410+
[](MlirAttribute self) -> py::object {
411+
uint32_t rawValue =
412+
ireeGPUVirtualMMAIntrinsicAttrGetValue(self);
413+
return py::module_::import_(kGpuModuleImportPath)
414+
.attr("VirtualMMAIntrinsic")(rawValue);
415+
})
416+
.def_property_readonly("mma", [](MlirAttribute self) -> MlirAttribute {
417+
uint32_t value = ireeGPUVirtualMMAIntrinsicAttrGetValue(self);
418+
return ireeGPUVirtualMMAAttrGet(mlirAttributeGetContext(self), value);
419+
});
420+
421+
//===-------------------------------------------------------------------===//
422+
// GPUVirtualMMAAttr
423+
//===-------------------------------------------------------------------===//
424+
425+
mlir_attribute_subclass(iree_gpu_module, "VirtualMMAAttr",
426+
ireeAttributeIsAGPUVirtualMMAAttr,
427+
ireeGPUVirtualMMAAttrGetTypeID)
428+
.def_classmethod(
429+
"get",
430+
[](const py::object &, uint32_t value, MlirContext ctx) {
431+
return ireeGPUVirtualMMAAttrGet(ctx, value);
432+
},
433+
"cls"_a, "value"_a, "ctx"_a = py::none(),
434+
"Gets an #iree_gpu.virtualmma from parameters.")
435+
.def_property_readonly(
436+
"abc_element_types",
437+
[](MlirAttribute self) -> py::tuple {
438+
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
439+
return py::make_tuple(info.aElementType, info.bElementType,
440+
info.cElementType);
441+
})
442+
.def_property_readonly(
443+
"abc_vector_types",
444+
[](MlirAttribute self) -> py::tuple {
445+
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
446+
return py::make_tuple(info.aVectorType, info.bVectorType,
447+
info.cVectorType);
448+
})
360449
.def_property_readonly("mnk_shape", [](MlirAttribute self) -> py::tuple {
361450
ireeGPUMMAInfo info = ireeGPUMMAAttrGetInfo(self);
362451
return py::make_tuple(info.mElements, info.nElements, info.kElements);

compiler/bindings/python/test/ir/dialects_test.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,11 @@ def mma_intrinsic_attr():
245245
assert c_type == f32
246246

247247
vec_4xf16 = ir.VectorType.get((4,), f16)
248-
a_vec_type, b_vec_type, _c_vec_type = mma_attr.abc_vector_types
248+
vec_16xf32 = ir.VectorType.get((16,), f32)
249+
a_vec_type, b_vec_type, c_vec_type = mma_attr.abc_vector_types
249250
assert a_vec_type == vec_4xf16
250251
assert b_vec_type == vec_4xf16
252+
assert c_vec_type == vec_16xf32
251253

252254
M, N, K = mma_attr.mnk_shape
253255
assert M == 32
@@ -256,6 +258,58 @@ def mma_intrinsic_attr():
256258

257259
assert mma_intrinsic_attr.mma == mma_attr
258260

261+
virtual_mma_intrinsics = mma_attr.get_virtual_intrinsics()
262+
assert isinstance(virtual_mma_intrinsics[0], iree_gpu.VirtualMMAIntrinsic)
263+
assert (
264+
virtual_mma_intrinsics[0] == iree_gpu.VirtualMMAIntrinsic.VMFMA_F32_32x32x16_F16
265+
)
266+
267+
mma_attr = iree_gpu.MMAAttr.get(iree_gpu.MMAIntrinsic.MFMA_F32_16x16x4_F32)
268+
virtual_mma_intrinsics = mma_attr.get_virtual_intrinsics()
269+
assert virtual_mma_intrinsics == []
270+
271+
272+
@run
273+
def virtual_mma_intrinsic_attr():
274+
virtual_mma_intrinsic_attr = iree_gpu.VirtualMMAIntrinsicAttr.get(
275+
iree_gpu.VirtualMMAIntrinsic.VMFMA_F32_16x16x32_F16
276+
)
277+
assert virtual_mma_intrinsic_attr is not None
278+
assert (
279+
str(virtual_mma_intrinsic_attr)
280+
== "#iree_gpu<virtual_mma_intrinsic VMFMA_F32_16x16x32_F16>"
281+
)
282+
283+
raw_value = virtual_mma_intrinsic_attr.raw_value
284+
assert raw_value == iree_gpu.VirtualMMAIntrinsic.VMFMA_F32_16x16x32_F16
285+
value = virtual_mma_intrinsic_attr.value
286+
assert str(value) == "VMFMA_F32_16x16x32_F16"
287+
assert int(value) == raw_value
288+
289+
virtual_mma_attr = iree_gpu.VirtualMMAAttr.get(raw_value)
290+
assert virtual_mma_attr is not None
291+
292+
f16 = ir.F16Type.get()
293+
f32 = ir.F32Type.get()
294+
a_type, b_type, c_type = virtual_mma_attr.abc_element_types
295+
assert a_type == f16
296+
assert b_type == f16
297+
assert c_type == f32
298+
299+
vec_4xf32 = ir.VectorType.get((4,), f32)
300+
vec_8xf16 = ir.VectorType.get((8,), f16)
301+
a_vec_type, b_vec_type, c_vec_type = virtual_mma_attr.abc_vector_types
302+
assert a_vec_type == vec_8xf16
303+
assert b_vec_type == vec_8xf16
304+
assert c_vec_type == vec_4xf32
305+
306+
M, N, K = virtual_mma_attr.mnk_shape
307+
assert M == 16
308+
assert N == 16
309+
assert K == 32
310+
311+
assert virtual_mma_intrinsic_attr.mma == virtual_mma_attr
312+
259313

260314
@run
261315
def lowering_config_attr():

compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,21 +133,45 @@ bool ireeAttributeIsAGPUMMAIntrinsicAttr(MlirAttribute attr) {
133133
unwrap(attr));
134134
}
135135

136+
bool ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(MlirAttribute attr) {
137+
return llvm::isa<mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr>(
138+
unwrap(attr));
139+
}
140+
136141
MlirTypeID ireeGPUMMAIntrinsicAttrGetTypeID() {
137142
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::getTypeID());
138143
}
139144

145+
MlirTypeID ireeGPUVirtualMMAIntrinsicAttrGetTypeID() {
146+
return wrap(
147+
mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr::getTypeID());
148+
}
149+
140150
static_assert(
141151
std::is_same_v<uint32_t, std::underlying_type_t<
142152
mlir::iree_compiler::IREE::GPU::MMAIntrinsic>>,
143153
"Enum type changed");
144154

155+
static_assert(
156+
std::is_same_v<uint32_t,
157+
std::underlying_type_t<
158+
mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic>>,
159+
"Enum type changed");
160+
145161
MlirAttribute ireeGPUMMAIntrinsicAttrGet(MlirContext mlirCtx, uint32_t value) {
146162
mlir::MLIRContext *ctx = unwrap(mlirCtx);
147163
return wrap(mlir::iree_compiler::IREE::GPU::MMAIntrinsicAttr::get(
148164
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
149165
}
150166

167+
MlirAttribute ireeGPUVirtualMMAIntrinsicAttrGet(MlirContext mlirCtx,
168+
uint32_t value) {
169+
mlir::MLIRContext *ctx = unwrap(mlirCtx);
170+
return wrap(mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr::get(
171+
ctx,
172+
static_cast<mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic>(value)));
173+
}
174+
151175
uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr) {
152176
assert(ireeAttributeIsAGPUMMAIntrinsicAttr(attr) &&
153177
"attr is not a GPUMMAIntrinsicAttr");
@@ -156,37 +180,85 @@ uint32_t ireeGPUMMAIntrinsicAttrGetValue(MlirAttribute attr) {
156180
.getValue());
157181
}
158182

183+
uint32_t ireeGPUVirtualMMAIntrinsicAttrGetValue(MlirAttribute attr) {
184+
assert(ireeAttributeIsAGPUVirtualMMAIntrinsicAttr(attr) &&
185+
"attr is not a GPUVirtualMMAIntrinsicAttr");
186+
return static_cast<uint32_t>(
187+
llvm::cast<mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsicAttr>(
188+
unwrap(attr))
189+
.getValue());
190+
}
191+
159192
bool ireeAttributeIsAGPUMMAAttr(MlirAttribute attr) {
160193
return llvm::isa<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));
161194
}
162195

196+
bool ireeAttributeIsAGPUVirtualMMAAttr(MlirAttribute attr) {
197+
return llvm::isa<mlir::iree_compiler::IREE::GPU::VirtualMMAAttr>(
198+
unwrap(attr));
199+
}
200+
163201
MlirTypeID ireeGPUMMAAttrGetTypeID() {
164202
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::getTypeID());
165203
}
166204

205+
MlirTypeID ireeGPUVirtualMMAAttrGetTypeID() {
206+
return wrap(mlir::iree_compiler::IREE::GPU::VirtualMMAAttr::getTypeID());
207+
}
208+
167209
MlirAttribute ireeGPUMMAAttrGet(MlirContext mlirCtx, uint32_t value) {
168210
mlir::MLIRContext *ctx = unwrap(mlirCtx);
169211
return wrap(mlir::iree_compiler::IREE::GPU::MMAAttr::get(
170212
ctx, static_cast<mlir::iree_compiler::IREE::GPU::MMAIntrinsic>(value)));
171213
}
172214

215+
MlirAttribute ireeGPUVirtualMMAAttrGet(MlirContext mlirCtx, uint32_t value) {
216+
mlir::MLIRContext *ctx = unwrap(mlirCtx);
217+
return wrap(mlir::iree_compiler::IREE::GPU::VirtualMMAAttr::get(
218+
ctx,
219+
static_cast<mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic>(value)));
220+
}
221+
173222
ireeGPUMMAInfo ireeGPUMMAAttrGetInfo(MlirAttribute attr) {
223+
return llvm::TypeSwitch<mlir::Attribute, ireeGPUMMAInfo>(unwrap(attr))
224+
.Case<mlir::iree_compiler::IREE::GPU::MMAAttr,
225+
mlir::iree_compiler::IREE::GPU::VirtualMMAAttr>([](auto mma) {
226+
ireeGPUMMAInfo info = {};
227+
auto [aType, bType, cType] = mma.getABCElementTypes();
228+
info.aElementType = wrap(aType);
229+
info.bElementType = wrap(bType);
230+
info.cElementType = wrap(cType);
231+
232+
auto [aVecType, bVecType, cVecType] = mma.getABCVectorTypes();
233+
info.aVectorType = wrap(aVecType);
234+
info.bVectorType = wrap(bVecType);
235+
info.cVectorType = wrap(cVecType);
236+
237+
std::tie(info.mElements, info.nElements, info.kElements) =
238+
mma.getMNKShape();
239+
240+
return info;
241+
})
242+
.Default([](mlir::Attribute) -> ireeGPUMMAInfo {
243+
assert(false && "Unexpected attribute type for MMA info");
244+
return {};
245+
});
246+
}
247+
248+
MlirAttribute ireeGPUMMAAttrGetVirtualMMAIntrinsic(MlirAttribute attr) {
174249
assert(ireeAttributeIsAGPUMMAAttr(attr) && "attr is not a MMAAttr");
175250
auto mma = llvm::cast<mlir::iree_compiler::IREE::GPU::MMAAttr>(unwrap(attr));
251+
llvm::SmallVector<mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic>
252+
virtualIntrinsics = mma.getVirtualIntrinsics();
176253

177-
ireeGPUMMAInfo info = {};
178-
auto [aType, bType, cType] = mma.getABCElementTypes();
179-
info.aElementType = wrap(aType);
180-
info.bElementType = wrap(bType);
181-
info.cElementType = wrap(cType);
182-
183-
auto [aVecType, bVecType, cVecType] = mma.getABCVectorTypes();
184-
info.aVectorType = wrap(aVecType);
185-
info.bVectorType = wrap(bVecType);
186-
info.cVectorType = wrap(cVecType);
254+
llvm::SmallVector<int64_t> rawValues;
255+
for (auto v : virtualIntrinsics) {
256+
rawValues.push_back(static_cast<int64_t>(v));
257+
}
187258

188-
std::tie(info.mElements, info.nElements, info.kElements) = mma.getMNKShape();
189-
return info;
259+
mlir::MLIRContext *ctx = mma.getContext();
260+
mlir::Builder builder(ctx);
261+
return wrap(builder.getI64ArrayAttr(rawValues));
190262
}
191263

192264
bool ireeAttributeIsAGPULoweringConfigAttr(MlirAttribute attr) {

0 commit comments

Comments
 (0)