|
10 | 10 |
|
11 | 11 | #include "PybindUtils.h" |
12 | 12 |
|
| 13 | +#include "mlir/Bindings/Python/IRTypes.h" |
| 14 | + |
13 | 15 | #include "mlir-c/BuiltinAttributes.h" |
14 | 16 | #include "mlir-c/BuiltinTypes.h" |
15 | 17 | #include "mlir-c/Support.h" |
@@ -418,98 +420,98 @@ class PyComplexType : public PyConcreteType<PyComplexType> { |
418 | 420 | } |
419 | 421 | }; |
420 | 422 |
|
421 | | -class PyShapedType : public PyConcreteType<PyShapedType> { |
422 | | -public: |
423 | | - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; |
424 | | - static constexpr const char *pyClassName = "ShapedType"; |
425 | | - using PyConcreteType::PyConcreteType; |
| 423 | +} // namespace |
426 | 424 |
|
427 | | - static void bindDerived(ClassTy &c) { |
428 | | - c.def_property_readonly( |
429 | | - "element_type", |
430 | | - [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, |
431 | | - "Returns the element type of the shaped type."); |
432 | | - c.def_property_readonly( |
433 | | - "has_rank", |
434 | | - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, |
435 | | - "Returns whether the given shaped type is ranked."); |
436 | | - c.def_property_readonly( |
437 | | - "rank", |
438 | | - [](PyShapedType &self) { |
439 | | - self.requireHasRank(); |
440 | | - return mlirShapedTypeGetRank(self); |
441 | | - }, |
442 | | - "Returns the rank of the given ranked shaped type."); |
443 | | - c.def_property_readonly( |
444 | | - "has_static_shape", |
445 | | - [](PyShapedType &self) -> bool { |
446 | | - return mlirShapedTypeHasStaticShape(self); |
447 | | - }, |
448 | | - "Returns whether the given shaped type has a static shape."); |
449 | | - c.def( |
450 | | - "is_dynamic_dim", |
451 | | - [](PyShapedType &self, intptr_t dim) -> bool { |
452 | | - self.requireHasRank(); |
453 | | - return mlirShapedTypeIsDynamicDim(self, dim); |
454 | | - }, |
455 | | - py::arg("dim"), |
456 | | - "Returns whether the dim-th dimension of the given shaped type is " |
457 | | - "dynamic."); |
458 | | - c.def( |
459 | | - "get_dim_size", |
460 | | - [](PyShapedType &self, intptr_t dim) { |
461 | | - self.requireHasRank(); |
462 | | - return mlirShapedTypeGetDimSize(self, dim); |
463 | | - }, |
464 | | - py::arg("dim"), |
465 | | - "Returns the dim-th dimension of the given ranked shaped type."); |
466 | | - c.def_static( |
467 | | - "is_dynamic_size", |
468 | | - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, |
469 | | - py::arg("dim_size"), |
470 | | - "Returns whether the given dimension size indicates a dynamic " |
471 | | - "dimension."); |
472 | | - c.def( |
473 | | - "is_dynamic_stride_or_offset", |
474 | | - [](PyShapedType &self, int64_t val) -> bool { |
475 | | - self.requireHasRank(); |
476 | | - return mlirShapedTypeIsDynamicStrideOrOffset(val); |
477 | | - }, |
478 | | - py::arg("dim_size"), |
479 | | - "Returns whether the given value is used as a placeholder for dynamic " |
480 | | - "strides and offsets in shaped types."); |
481 | | - c.def_property_readonly( |
482 | | - "shape", |
483 | | - [](PyShapedType &self) { |
484 | | - self.requireHasRank(); |
485 | | - |
486 | | - std::vector<int64_t> shape; |
487 | | - int64_t rank = mlirShapedTypeGetRank(self); |
488 | | - shape.reserve(rank); |
489 | | - for (int64_t i = 0; i < rank; ++i) |
490 | | - shape.push_back(mlirShapedTypeGetDimSize(self, i)); |
491 | | - return shape; |
492 | | - }, |
493 | | - "Returns the shape of the ranked shaped type as a list of integers."); |
494 | | - c.def_static( |
495 | | - "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, |
496 | | - "Returns the value used to indicate dynamic dimensions in shaped " |
497 | | - "types."); |
498 | | - c.def_static( |
499 | | - "get_dynamic_stride_or_offset", |
500 | | - []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, |
501 | | - "Returns the value used to indicate dynamic strides or offsets in " |
502 | | - "shaped types."); |
503 | | - } |
| 425 | +// Shaped Type Interface - ShapedType |
| 426 | +void mlir::PyShapedType::bindDerived(ClassTy &c) { |
| 427 | + c.def_property_readonly( |
| 428 | + "element_type", |
| 429 | + [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, |
| 430 | + "Returns the element type of the shaped type."); |
| 431 | + c.def_property_readonly( |
| 432 | + "has_rank", |
| 433 | + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, |
| 434 | + "Returns whether the given shaped type is ranked."); |
| 435 | + c.def_property_readonly( |
| 436 | + "rank", |
| 437 | + [](PyShapedType &self) { |
| 438 | + self.requireHasRank(); |
| 439 | + return mlirShapedTypeGetRank(self); |
| 440 | + }, |
| 441 | + "Returns the rank of the given ranked shaped type."); |
| 442 | + c.def_property_readonly( |
| 443 | + "has_static_shape", |
| 444 | + [](PyShapedType &self) -> bool { |
| 445 | + return mlirShapedTypeHasStaticShape(self); |
| 446 | + }, |
| 447 | + "Returns whether the given shaped type has a static shape."); |
| 448 | + c.def( |
| 449 | + "is_dynamic_dim", |
| 450 | + [](PyShapedType &self, intptr_t dim) -> bool { |
| 451 | + self.requireHasRank(); |
| 452 | + return mlirShapedTypeIsDynamicDim(self, dim); |
| 453 | + }, |
| 454 | + py::arg("dim"), |
| 455 | + "Returns whether the dim-th dimension of the given shaped type is " |
| 456 | + "dynamic."); |
| 457 | + c.def( |
| 458 | + "get_dim_size", |
| 459 | + [](PyShapedType &self, intptr_t dim) { |
| 460 | + self.requireHasRank(); |
| 461 | + return mlirShapedTypeGetDimSize(self, dim); |
| 462 | + }, |
| 463 | + py::arg("dim"), |
| 464 | + "Returns the dim-th dimension of the given ranked shaped type."); |
| 465 | + c.def_static( |
| 466 | + "is_dynamic_size", |
| 467 | + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, |
| 468 | + py::arg("dim_size"), |
| 469 | + "Returns whether the given dimension size indicates a dynamic " |
| 470 | + "dimension."); |
| 471 | + c.def( |
| 472 | + "is_dynamic_stride_or_offset", |
| 473 | + [](PyShapedType &self, int64_t val) -> bool { |
| 474 | + self.requireHasRank(); |
| 475 | + return mlirShapedTypeIsDynamicStrideOrOffset(val); |
| 476 | + }, |
| 477 | + py::arg("dim_size"), |
| 478 | + "Returns whether the given value is used as a placeholder for dynamic " |
| 479 | + "strides and offsets in shaped types."); |
| 480 | + c.def_property_readonly( |
| 481 | + "shape", |
| 482 | + [](PyShapedType &self) { |
| 483 | + self.requireHasRank(); |
| 484 | + |
| 485 | + std::vector<int64_t> shape; |
| 486 | + int64_t rank = mlirShapedTypeGetRank(self); |
| 487 | + shape.reserve(rank); |
| 488 | + for (int64_t i = 0; i < rank; ++i) |
| 489 | + shape.push_back(mlirShapedTypeGetDimSize(self, i)); |
| 490 | + return shape; |
| 491 | + }, |
| 492 | + "Returns the shape of the ranked shaped type as a list of integers."); |
| 493 | + c.def_static( |
| 494 | + "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, |
| 495 | + "Returns the value used to indicate dynamic dimensions in shaped " |
| 496 | + "types."); |
| 497 | + c.def_static( |
| 498 | + "get_dynamic_stride_or_offset", |
| 499 | + []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, |
| 500 | + "Returns the value used to indicate dynamic strides or offsets in " |
| 501 | + "shaped types."); |
| 502 | +} |
504 | 503 |
|
505 | | -private: |
506 | | - void requireHasRank() { |
507 | | - if (!mlirShapedTypeHasRank(*this)) { |
508 | | - throw py::value_error( |
509 | | - "calling this method requires that the type has a rank."); |
510 | | - } |
| 504 | +void mlir::PyShapedType::requireHasRank() { |
| 505 | + if (!mlirShapedTypeHasRank(*this)) { |
| 506 | + throw py::value_error( |
| 507 | + "calling this method requires that the type has a rank."); |
511 | 508 | } |
512 | | -}; |
| 509 | +} |
| 510 | + |
| 511 | +const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = |
| 512 | + mlirTypeIsAShaped; |
| 513 | + |
| 514 | +namespace { |
513 | 515 |
|
514 | 516 | /// Vector Type subclass - VectorType. |
515 | 517 | class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { |
|
0 commit comments