Skip to content

Commit 8993ece

Browse files
authored
Use isa for dialect matching instead of string comparison (#2757)
1 parent af8ba04 commit 8993ece

File tree

6 files changed

+9
-17
lines changed

6 files changed

+9
-17
lines changed

stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ bool allOperandsAreScalarTensors(Operation *op) {
178178

179179
bool isInBodyOfLinalgOps(Operation *op) {
180180
auto *parentOp = op->getParentRegion()->getParentOp();
181-
return parentOp->getDialect() ==
182-
parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
181+
return llvm::isa<linalg::LinalgDialect>(parentOp->getDialect());
183182
}
184183

185184
SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements) {

stablehlo/dialect/VhloOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ namespace {
5252
// Helper functions for VHLO verifiers
5353
template <typename TypeOrAttr>
5454
bool isFromVhlo(TypeOrAttr t) {
55-
return t.getDialect().getNamespace() == VhloDialect::getDialectNamespace();
55+
return llvm::isa<VhloDialect>(t.getDialect());
5656
}
5757

5858
template <typename TypeOrAttr>

stablehlo/dialect/VhloTypes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ namespace {
326326
// Helper functions for VHLO verifiers
327327
template <typename TypeOrAttr>
328328
bool isFromVhlo(TypeOrAttr t) {
329+
// Requires string comparison to avoid cyclic dependency with VhloOps.h.
329330
return t.getDialect().getNamespace() == "vhlo";
330331
}
331332

stablehlo/transforms/StablehloLegalizeToVhlo.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,8 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
5555
public:
5656
StablehloToVhloTypeConverter() : vhlo::VhloTypeConverter() {
5757
addConversion([](Type type) -> Type {
58-
if (type.getDialect().getNamespace() ==
59-
vhlo::VhloDialect::getDialectNamespace()) {
60-
return type;
61-
}
58+
if (llvm::isa<vhlo::VhloDialect>(type.getDialect())) return type;
59+
6260
LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n');
6361
return {};
6462
});
@@ -71,9 +69,7 @@ class StablehloToVhloTypeConverter : public vhlo::VhloTypeConverter {
7169
Attribute convertEncoding(Attribute attr) const final {
7270
LLVM_DEBUG(llvm::dbgs() << "Converting encoding.\n" << attr << '\n');
7371
// Must be VHLO encoding, or convertible to VHLO encoding.
74-
if (attr.getDialect().getNamespace() ==
75-
vhlo::VhloDialect::getDialectNamespace())
76-
return attr;
72+
if (llvm::isa<vhlo::VhloDialect>(attr.getDialect())) return attr;
7773

7874
if (auto stablehloAttr =
7975
dyn_cast_or_null<stablehlo::TypeExtensionsAttr>(attr)) {
@@ -141,8 +137,7 @@ Attribute convertGeneric(Attribute stablehloAttr,
141137
attr.getRtol(), attr.getUlps(),
142138
modeAttr);
143139
}
144-
if (stablehloAttr.getDialect().getNamespace() ==
145-
stablehlo::StablehloDialect::getDialectNamespace()) {
140+
if (llvm::isa<stablehlo::StablehloDialect>(stablehloAttr.getDialect())) {
146141
// All StableHLO attributes must have counterparts in VHLO.
147142
return {};
148143
}

stablehlo/transforms/VhloLegalizeToStablehlo.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,7 @@ Attribute convertGeneric(Attribute vhloAttr,
183183
}
184184

185185
// All VHLO Attributes must be converted by now.
186-
if (vhloAttr.getDialect().getNamespace() ==
187-
vhlo::VhloDialect::getDialectNamespace()) {
186+
if (llvm::isa<vhlo::VhloDialect>(vhloAttr.getDialect())) {
188187
// All VHLO attributes must have counterparts in StableHLO.
189188
return {};
190189
}

stablehlo/transforms/VhloToVersion.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ class VhloToVersionConverter : public TypeConverter {
6060
public:
6161
VhloToVersionConverter() : TypeConverter() {
6262
addConversion([](Type type) -> Type {
63-
if (type.getDialect().getNamespace() ==
64-
vhlo::VhloDialect::getDialectNamespace())
65-
return type;
63+
if (llvm::isa<vhlo::VhloDialect>(type.getDialect())) return type;
6664
LLVM_DEBUG(llvm::dbgs() << "Invalid type: " << type << '\n');
6765
return {};
6866
});

0 commit comments

Comments
 (0)