Skip to content

Commit 1456dfa

Browse files
authored
Add DotAlgorithm to StableHLO Python API (#2521)
1 parent 21dcdd2 commit 1456dfa

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

stablehlo/integrations/c/StablehloAttributes.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,60 @@ int64_t stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
212212
.getIndexVectorDim();
213213
}
214214

215+
//===----------------------------------------------------------------------===//
216+
// DotAlgorithm
217+
//===----------------------------------------------------------------------===//
218+
219+
MlirAttribute stablehloDotAlgorithmGet(
220+
MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
221+
MlirType accumulationType, int64_t lhsComponentCount,
222+
int64_t rhsComponentCount, int64_t numPrimitiveOperations,
223+
bool allowImpreciseAccumulation) {
224+
return wrap(mlir::stablehlo::DotAlgorithmAttr::get(
225+
unwrap(ctx), unwrap(lhsPrecisionType), unwrap(rhsPrecisionType),
226+
unwrap(accumulationType), lhsComponentCount, rhsComponentCount,
227+
numPrimitiveOperations, allowImpreciseAccumulation));
228+
}
229+
230+
bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr) {
231+
return llvm::isa<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr));
232+
}
233+
234+
MlirType stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr) {
235+
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
236+
.getLhsPrecisionType());
237+
}
238+
239+
MlirType stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr) {
240+
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
241+
.getRhsPrecisionType());
242+
}
243+
244+
MlirType stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr) {
245+
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
246+
.getAccumulationType());
247+
}
248+
249+
int64_t stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr) {
250+
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
251+
.getLhsComponentCount();
252+
}
253+
254+
int64_t stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr) {
255+
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
256+
.getRhsComponentCount();
257+
}
258+
259+
int64_t stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr) {
260+
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
261+
.getNumPrimitiveOperations();
262+
}
263+
264+
bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(MlirAttribute attr) {
265+
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
266+
.getAllowImpreciseAccumulation();
267+
}
268+
215269
//===----------------------------------------------------------------------===//
216270
// DotDimensionNumbers
217271
//===----------------------------------------------------------------------===//

stablehlo/integrations/c/StablehloAttributes.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,39 @@ MLIR_CAPI_EXPORTED int64_t stablehloGatherDimensionNumbersGetStartIndexMapElem(
113113
MLIR_CAPI_EXPORTED int64_t
114114
stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr);
115115

116+
//===----------------------------------------------------------------------===//
117+
// DotAlgorithm
118+
//===----------------------------------------------------------------------===//
119+
120+
MLIR_CAPI_EXPORTED MlirAttribute stablehloDotAlgorithmGet(
121+
MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
122+
MlirType accumulationType, int64_t lhsComponentCount,
123+
int64_t rhsComponentCount, int64_t numPrimitiveOperations,
124+
bool allowImpreciseAccumulation);
125+
126+
MLIR_CAPI_EXPORTED bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr);
127+
128+
MLIR_CAPI_EXPORTED MlirType
129+
stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr);
130+
131+
MLIR_CAPI_EXPORTED MlirType
132+
stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr);
133+
134+
MLIR_CAPI_EXPORTED MlirType
135+
stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr);
136+
137+
MLIR_CAPI_EXPORTED int64_t
138+
stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr);
139+
140+
MLIR_CAPI_EXPORTED int64_t
141+
stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr);
142+
143+
MLIR_CAPI_EXPORTED int64_t
144+
stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr);
145+
146+
MLIR_CAPI_EXPORTED bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(
147+
MlirAttribute attr);
148+
116149
//===----------------------------------------------------------------------===//
117150
// DotDimensionNumbers
118151
//===----------------------------------------------------------------------===//

stablehlo/integrations/python/StablehloModule.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,62 @@ PYBIND11_MODULE(_stablehlo, m) {
220220
return stablehloGatherDimensionNumbersGetIndexVectorDim(self);
221221
});
222222

223+
mlir::python::adaptors::mlir_attribute_subclass(
224+
m, "DotAlgorithm", stablehloAttributeIsADotAlgorithm)
225+
.def_classmethod(
226+
"get",
227+
[](py::object cls, MlirType lhsPrecisionType,
228+
MlirType rhsPrecisionType, MlirType accumulationType,
229+
int64_t lhsComponentCount, int64_t rhsComponentCount,
230+
int64_t numPrimitiveOperations, bool allowImpreciseAccumulation,
231+
MlirContext ctx) {
232+
return cls(stablehloDotAlgorithmGet(
233+
ctx, lhsPrecisionType, rhsPrecisionType, accumulationType,
234+
lhsComponentCount, rhsComponentCount, numPrimitiveOperations,
235+
allowImpreciseAccumulation));
236+
},
237+
py::arg("cls"), py::arg("lhs_precision_type"),
238+
py::arg("rhs_precision_type"), py::arg("accumulation_type"),
239+
py::arg("lhs_component_count"), py::arg("rhs_component_count"),
240+
py::arg("num_primitive_operations"),
241+
py::arg("allow_imprecise_accumulation"), py::arg("ctx") = py::none(),
242+
"Creates a DotAlgorithm attribute with the given dimension "
243+
"configuration.")
244+
.def_property_readonly(
245+
"lhs_precision_type",
246+
[](MlirAttribute self) {
247+
return stablehloDotAlgorithmGetLhsPrecisionType(self);
248+
})
249+
.def_property_readonly(
250+
"rhs_precision_type",
251+
[](MlirAttribute self) {
252+
return stablehloDotAlgorithmGetRhsPrecisionType(self);
253+
})
254+
.def_property_readonly(
255+
"accumulation_type",
256+
[](MlirAttribute self) {
257+
return stablehloDotAlgorithmGetAccumulationType(self);
258+
})
259+
.def_property_readonly(
260+
"lhs_component_count",
261+
[](MlirAttribute self) {
262+
return stablehloDotAlgorithmGetLhsComponentCount(self);
263+
})
264+
.def_property_readonly(
265+
"rhs_component_count",
266+
[](MlirAttribute self) {
267+
return stablehloDotAlgorithmGetRhsComponentCount(self);
268+
})
269+
.def_property_readonly(
270+
"num_primitive_operations",
271+
[](MlirAttribute self) {
272+
return stablehloDotAlgorithmGetNumPrimitiveOperations(self);
273+
})
274+
.def_property_readonly(
275+
"allow_imprecise_accumulation", [](MlirAttribute self) {
276+
return stablehloDotAlgorithmGetAllowImpreciseAccumulation(self);
277+
});
278+
223279
mlir::python::adaptors::mlir_attribute_subclass(
224280
m, "DotDimensionNumbers", stablehloAttributeIsADotDimensionNumbers)
225281
.def_classmethod(

stablehlo/integrations/python/tests/stablehlo.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,32 @@ def test_conv_dimension_numbers():
8282
assert attr.output_spatial_dimensions == [2, 3]
8383

8484

85+
@run
86+
def test_dot_algorithm():
87+
# BF16_BF16_F32_X3
88+
attr = stablehlo.DotAlgorithm.get(
89+
lhs_precision_type=ir.BF16Type.get(),
90+
rhs_precision_type=ir.BF16Type.get(),
91+
accumulation_type=ir.F32Type.get(),
92+
lhs_component_count=1,
93+
rhs_component_count=1,
94+
num_primitive_operations=3,
95+
allow_imprecise_accumulation=False)
96+
assert attr is not None
97+
assert str(attr) == ("#stablehlo.dot_algorithm<lhs_precision_type = bf16, "
98+
"rhs_precision_type = bf16, accumulation_type = f32, "
99+
"lhs_component_count = 1, rhs_component_count = 1, "
100+
"num_primitive_operations = 3, "
101+
"allow_imprecise_accumulation = false>")
102+
assert isinstance(attr.lhs_precision_type, ir.BF16Type)
103+
assert isinstance(attr.rhs_precision_type, ir.BF16Type)
104+
assert isinstance(attr.accumulation_type, ir.F32Type)
105+
assert attr.lhs_component_count == 1
106+
assert attr.rhs_component_count == 1
107+
assert attr.num_primitive_operations == 3
108+
assert attr.allow_imprecise_accumulation == False
109+
110+
85111
@run
86112
def test_dot_dimension_numbers():
87113
attr = stablehlo.DotDimensionNumbers.get(

0 commit comments

Comments
 (0)