@@ -55,10 +55,8 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
5555 public:
5656 StablehloToVhloTypeConverter () : vhlo::VhloTypeConverter() {
5757 addConversion ([](Type type) -> Type {
58- if (type.getDialect ().getNamespace () ==
59- vhlo::VhloDialect::getDialectNamespace ()) {
60- return type;
61- }
58+ if (llvm::isa<vhlo::VhloDialect>(type.getDialect ())) return type;
59+
6260 LLVM_DEBUG (llvm::dbgs () << " Invalid type: " << type << ' \n ' );
6361 return {};
6462 });
@@ -71,9 +69,7 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
7169 Attribute convertEncoding (Attribute attr) const final {
7270 LLVM_DEBUG (llvm::dbgs () << " Converting encoding.\n " << attr << ' \n ' );
7371 // Must be VHLO encoding, or convertible to VHLO encoding.
74- if (attr.getDialect ().getNamespace () ==
75- vhlo::VhloDialect::getDialectNamespace ())
76- return attr;
72+ if (llvm::isa<vhlo::VhloDialect>(attr.getDialect ())) return attr;
7773
7874 if (auto stablehloAttr =
7975 dyn_cast_or_null<stablehlo::TypeExtensionsAttr>(attr)) {
@@ -141,8 +137,7 @@ Attribute convertGeneric(Attribute stablehloAttr,
141137 attr.getRtol (), attr.getUlps (),
142138 modeAttr);
143139 }
144- if (stablehloAttr.getDialect ().getNamespace () ==
145- stablehlo::StablehloDialect::getDialectNamespace ()) {
140+ if (llvm::isa<stablehlo::StablehloDialect>(stablehloAttr.getDialect ())) {
146141 // All StableHLO attributes must have counterparts in VHLO.
147142 return {};
148143 }
0 commit comments