Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit dfe6aff

Browse files
[mlir][Python] Make PyShapedType public (#106105)
Make `PyShapedType` public, so that downstream projects can define types that implement the `ShapedType` type interface in Python.
1 parent c97fb8b commit dfe6aff

File tree

2 files changed

+122
-89
lines changed

2 files changed

+122
-89
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- IRTypes.h - Type Interfaces ----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
10+
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
11+
12+
#include "mlir/Bindings/Python/PybindAdaptors.h"
13+
14+
namespace mlir {
15+
16+
/// Shaped Type Interface - ShapedType
17+
class PyShapedType : public python::PyConcreteType<PyShapedType> {
18+
public:
19+
static const IsAFunctionTy isaFunction;
20+
static constexpr const char *pyClassName = "ShapedType";
21+
using PyConcreteType::PyConcreteType;
22+
23+
static void bindDerived(ClassTy &c);
24+
25+
private:
26+
void requireHasRank();
27+
};
28+
29+
} // namespace mlir
30+
31+
#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 91 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include "PybindUtils.h"
1212

13+
#include "mlir/Bindings/Python/IRTypes.h"
14+
1315
#include "mlir-c/BuiltinAttributes.h"
1416
#include "mlir-c/BuiltinTypes.h"
1517
#include "mlir-c/Support.h"
@@ -418,98 +420,98 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
418420
}
419421
};
420422

421-
class PyShapedType : public PyConcreteType<PyShapedType> {
422-
public:
423-
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
424-
static constexpr const char *pyClassName = "ShapedType";
425-
using PyConcreteType::PyConcreteType;
423+
} // namespace
426424

427-
static void bindDerived(ClassTy &c) {
428-
c.def_property_readonly(
429-
"element_type",
430-
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
431-
"Returns the element type of the shaped type.");
432-
c.def_property_readonly(
433-
"has_rank",
434-
[](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
435-
"Returns whether the given shaped type is ranked.");
436-
c.def_property_readonly(
437-
"rank",
438-
[](PyShapedType &self) {
439-
self.requireHasRank();
440-
return mlirShapedTypeGetRank(self);
441-
},
442-
"Returns the rank of the given ranked shaped type.");
443-
c.def_property_readonly(
444-
"has_static_shape",
445-
[](PyShapedType &self) -> bool {
446-
return mlirShapedTypeHasStaticShape(self);
447-
},
448-
"Returns whether the given shaped type has a static shape.");
449-
c.def(
450-
"is_dynamic_dim",
451-
[](PyShapedType &self, intptr_t dim) -> bool {
452-
self.requireHasRank();
453-
return mlirShapedTypeIsDynamicDim(self, dim);
454-
},
455-
py::arg("dim"),
456-
"Returns whether the dim-th dimension of the given shaped type is "
457-
"dynamic.");
458-
c.def(
459-
"get_dim_size",
460-
[](PyShapedType &self, intptr_t dim) {
461-
self.requireHasRank();
462-
return mlirShapedTypeGetDimSize(self, dim);
463-
},
464-
py::arg("dim"),
465-
"Returns the dim-th dimension of the given ranked shaped type.");
466-
c.def_static(
467-
"is_dynamic_size",
468-
[](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
469-
py::arg("dim_size"),
470-
"Returns whether the given dimension size indicates a dynamic "
471-
"dimension.");
472-
c.def(
473-
"is_dynamic_stride_or_offset",
474-
[](PyShapedType &self, int64_t val) -> bool {
475-
self.requireHasRank();
476-
return mlirShapedTypeIsDynamicStrideOrOffset(val);
477-
},
478-
py::arg("dim_size"),
479-
"Returns whether the given value is used as a placeholder for dynamic "
480-
"strides and offsets in shaped types.");
481-
c.def_property_readonly(
482-
"shape",
483-
[](PyShapedType &self) {
484-
self.requireHasRank();
485-
486-
std::vector<int64_t> shape;
487-
int64_t rank = mlirShapedTypeGetRank(self);
488-
shape.reserve(rank);
489-
for (int64_t i = 0; i < rank; ++i)
490-
shape.push_back(mlirShapedTypeGetDimSize(self, i));
491-
return shape;
492-
},
493-
"Returns the shape of the ranked shaped type as a list of integers.");
494-
c.def_static(
495-
"get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
496-
"Returns the value used to indicate dynamic dimensions in shaped "
497-
"types.");
498-
c.def_static(
499-
"get_dynamic_stride_or_offset",
500-
[]() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
501-
"Returns the value used to indicate dynamic strides or offsets in "
502-
"shaped types.");
503-
}
425+
// Shaped Type Interface - ShapedType
426+
void mlir::PyShapedType::bindDerived(ClassTy &c) {
427+
c.def_property_readonly(
428+
"element_type",
429+
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
430+
"Returns the element type of the shaped type.");
431+
c.def_property_readonly(
432+
"has_rank",
433+
[](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
434+
"Returns whether the given shaped type is ranked.");
435+
c.def_property_readonly(
436+
"rank",
437+
[](PyShapedType &self) {
438+
self.requireHasRank();
439+
return mlirShapedTypeGetRank(self);
440+
},
441+
"Returns the rank of the given ranked shaped type.");
442+
c.def_property_readonly(
443+
"has_static_shape",
444+
[](PyShapedType &self) -> bool {
445+
return mlirShapedTypeHasStaticShape(self);
446+
},
447+
"Returns whether the given shaped type has a static shape.");
448+
c.def(
449+
"is_dynamic_dim",
450+
[](PyShapedType &self, intptr_t dim) -> bool {
451+
self.requireHasRank();
452+
return mlirShapedTypeIsDynamicDim(self, dim);
453+
},
454+
py::arg("dim"),
455+
"Returns whether the dim-th dimension of the given shaped type is "
456+
"dynamic.");
457+
c.def(
458+
"get_dim_size",
459+
[](PyShapedType &self, intptr_t dim) {
460+
self.requireHasRank();
461+
return mlirShapedTypeGetDimSize(self, dim);
462+
},
463+
py::arg("dim"),
464+
"Returns the dim-th dimension of the given ranked shaped type.");
465+
c.def_static(
466+
"is_dynamic_size",
467+
[](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
468+
py::arg("dim_size"),
469+
"Returns whether the given dimension size indicates a dynamic "
470+
"dimension.");
471+
c.def(
472+
"is_dynamic_stride_or_offset",
473+
[](PyShapedType &self, int64_t val) -> bool {
474+
self.requireHasRank();
475+
return mlirShapedTypeIsDynamicStrideOrOffset(val);
476+
},
477+
py::arg("dim_size"),
478+
"Returns whether the given value is used as a placeholder for dynamic "
479+
"strides and offsets in shaped types.");
480+
c.def_property_readonly(
481+
"shape",
482+
[](PyShapedType &self) {
483+
self.requireHasRank();
484+
485+
std::vector<int64_t> shape;
486+
int64_t rank = mlirShapedTypeGetRank(self);
487+
shape.reserve(rank);
488+
for (int64_t i = 0; i < rank; ++i)
489+
shape.push_back(mlirShapedTypeGetDimSize(self, i));
490+
return shape;
491+
},
492+
"Returns the shape of the ranked shaped type as a list of integers.");
493+
c.def_static(
494+
"get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
495+
"Returns the value used to indicate dynamic dimensions in shaped "
496+
"types.");
497+
c.def_static(
498+
"get_dynamic_stride_or_offset",
499+
[]() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
500+
"Returns the value used to indicate dynamic strides or offsets in "
501+
"shaped types.");
502+
}
504503

505-
private:
506-
void requireHasRank() {
507-
if (!mlirShapedTypeHasRank(*this)) {
508-
throw py::value_error(
509-
"calling this method requires that the type has a rank.");
510-
}
504+
void mlir::PyShapedType::requireHasRank() {
505+
if (!mlirShapedTypeHasRank(*this)) {
506+
throw py::value_error(
507+
"calling this method requires that the type has a rank.");
511508
}
512-
};
509+
}
510+
511+
const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
512+
mlirTypeIsAShaped;
513+
514+
namespace {
513515

514516
/// Vector Type subclass - VectorType.
515517
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {

0 commit comments

Comments
 (0)