Skip to content

Commit c81c97f

Browse files
committed
Allow access to $_builder in Attr's defaultValue
Reverts changes to DefaultValued(Optional)Attr.
1 parent 538566a commit c81c97f

File tree

5 files changed

+47
-27
lines changed

5 files changed

+47
-27
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -606,10 +606,10 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
606606
let arguments = (ins
607607
Variadic<AnyType>:$inputs,
608608
Variadic<AnyShaped>:$outputs,
609-
DefaultValuedOptionalAttr<AffineMapArrayAttr, "SmallVector<AffineMap>()",
610-
builderCall = [{ $_builder.getAffineMapArrayAttr(
611-
$0.empty() ? MatmulOp::getDefaultIndexingMaps($_builder.getContext()) : $0
612-
)}]>:$indexing_maps,
609+
DefaultValuedOptionalAttr<
610+
AffineMapArrayAttr,
611+
"MatmulOp::getDefaultIndexingMaps($_builder.getContext())"
612+
>:$indexing_maps,
613613
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
614614
);
615615
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class Attr<Pred condition, string summary = ""> :
5050

5151
// Default value for attribute.
5252
// Requires a constBuilderCall defined.
53+
//
54+
// Format: `$_builder` will be expanded to the relevant builder, e.g. to allow
55+
// access to the current context.
5356
string defaultValue = ?;
5457

5558
// The value type of this attribute. This corresponds to the mlir::Type that
@@ -90,15 +93,15 @@ class DialectAttr<Dialect d, Pred condition, string summary = ""> :
9093
// Attribute modifier definition
9194

9295
// Decorates an attribute to have an (unvalidated) default value if not present.
93-
class DefaultValuedAttr<Attr attr, string val, string builderCall = ""> :
96+
class DefaultValuedAttr<Attr attr, string val> :
9497
Attr<attr.predicate, attr.summary> {
9598
// Construct this attribute with the input attribute and change only
9699
// the default value.
97100
// Note: this has to be kept up to date with Attr above.
98101
let storageType = attr.storageType;
99102
let returnType = attr.returnType;
100103
let convertFromStorage = attr.convertFromStorage;
101-
let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
104+
let constBuilderCall = attr.constBuilderCall;
102105
let defaultValue = val;
103106
let valueType = attr.valueType;
104107

@@ -107,15 +110,15 @@ class DefaultValuedAttr<Attr attr, string val, string builderCall = ""> :
107110

108111
// Decorates an optional attribute to have an (unvalidated) default value
109112
// return by ODS generated accessors if not present.
110-
class DefaultValuedOptionalAttr<Attr attr, string val, string builderCall = ""> :
113+
class DefaultValuedOptionalAttr<Attr attr, string val> :
111114
Attr<attr.predicate, attr.summary> {
112115
// Construct this attribute with the input attribute and change only
113116
// the default value.
114117
// Note: this has to be kept up to date with Attr above.
115118
let storageType = attr.storageType;
116119
let returnType = attr.returnType;
117120
let convertFromStorage = attr.convertFromStorage;
118-
let constBuilderCall = !if(!eq(builderCall, ""), attr.constBuilderCall, builderCall);
121+
let constBuilderCall = attr.constBuilderCall;
119122
let defaultValue = val;
120123
let valueType = attr.valueType;
121124
let isOptional = 1;

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,8 +1334,9 @@ static void emitAttrGetterWithReturnType(FmtContext &fctx,
13341334
PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
13351335
" must have a constBuilder");
13361336
}
1337-
std::string defaultValue = std::string(
1338-
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1337+
std::string defaultValue =
1338+
std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
1339+
tgfmt(attr.getDefaultValue(), &fctx)));
13391340
body << " if (!attr)\n return "
13401341
<< tgfmt(attr.getConvertFromStorageCall(),
13411342
&fctx.withSelf(defaultValue))
@@ -1467,6 +1468,7 @@ void OpEmitter::genPropertiesSupport() {
14671468
os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
14681469
}
14691470

1471+
fctx.withBuilder(odsBuilder);
14701472
setPropMethod << "{\n"
14711473
<< formatv(propFromAttrFmt,
14721474
tgfmt(prop.getConvertFromAttributeCall(),
@@ -1479,7 +1481,7 @@ void OpEmitter::genPropertiesSupport() {
14791481
prop.getStorageTypeValueOverride());
14801482
} else if (prop.hasDefaultValue()) {
14811483
setPropMethod << formatv(attrGetDefaultFmt, name,
1482-
prop.getDefaultValue());
1484+
tgfmt(prop.getDefaultValue(), &fctx));
14831485
} else {
14841486
setPropMethod << formatv(attrGetNoDefaultFmt, name);
14851487
}
@@ -2919,6 +2921,9 @@ getBuilderSignature(const Builder &builder) {
29192921
arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
29202922
arguments.emplace_back("::mlir::OperationState &", builderOpState);
29212923

2924+
FmtContext fctx;
2925+
fctx.withBuilder(odsBuilder);
2926+
29222927
for (unsigned i = 0, e = params.size(); i < e; ++i) {
29232928
// If no name is provided, generate one.
29242929
std::optional<StringRef> paramName = params[i].getName();
@@ -2931,7 +2936,7 @@ getBuilderSignature(const Builder &builder) {
29312936
defaultValue = *defaultParamValue;
29322937

29332938
arguments.emplace_back(params[i].getCppType(), std::move(name),
2934-
defaultValue);
2939+
tgfmt(defaultValue, &fctx));
29352940
}
29362941

29372942
return arguments;
@@ -3189,6 +3194,9 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
31893194
}
31903195
}
31913196

3197+
FmtContext fctx;
3198+
fctx.withBuilder(odsBuilder);
3199+
31923200
for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
31933201
Argument arg = op.getArg(i);
31943202
if (const auto *operand =
@@ -3210,7 +3218,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
32103218
StringRef type = prop.getInterfaceType();
32113219
std::string defaultValue;
32123220
if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) {
3213-
defaultValue = prop.getDefaultValue();
3221+
defaultValue = tgfmt(prop.getDefaultValue(), &fctx);
32143222
}
32153223
bool isOptional = prop.hasDefaultValue();
32163224
paramList.emplace_back(type, propArg->name, StringRef(defaultValue),
@@ -3242,7 +3250,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
32423250
if (i >= defaultValuedAttrStartIndex) {
32433251
if (attrParamKind == AttrParamKind::UnwrappedValue &&
32443252
canUseUnwrappedRawValue(attr))
3245-
defaultValue += attr.getDefaultValue();
3253+
defaultValue += tgfmt(attr.getDefaultValue(), &fctx);
32463254
else
32473255
defaultValue += "nullptr";
32483256
}
@@ -4172,6 +4180,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
41724180
staticVerifierEmitter(staticVerifierEmitter),
41734181
emitHelper(op, /*emitForOp=*/false) {
41744182

4183+
FmtContext fctx;
4184+
fctx.withBuilder(odsBuilder);
4185+
41754186
genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public);
41764187
bool useProperties = emitHelper.hasProperties();
41774188
if (useProperties) {
@@ -4212,7 +4223,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
42124223
if (prop.hasStorageTypeValueOverride())
42134224
os << " = " << prop.getStorageTypeValueOverride();
42144225
else if (prop.hasDefaultValue())
4215-
os << " = " << prop.getDefaultValue();
4226+
os << " = " << tgfmt(prop.getDefaultValue(), &fctx);
42164227
comparatorOs << " rhs." << name << " == this->" << name
42174228
<< " &&\n";
42184229
// Emit accessors using the interface type.
@@ -4454,7 +4465,6 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
44544465
if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
44554466
m->body() << " return odsOperands;";
44564467

4457-
FmtContext fctx;
44584468
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
44594469

44604470
// Generate named accessor with Attribute return type.
@@ -4481,8 +4491,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
44814491
// Use the default value if attribute is not set.
44824492
// TODO: this is inefficient, we are recreating the attribute for every
44834493
// call. This should be set instead.
4484-
std::string defaultValue = std::string(
4485-
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
4494+
std::string defaultValue =
4495+
std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
4496+
tgfmt(attr.getDefaultValue(), &fctx)));
44864497
body << "if (!attr)\n attr = " << defaultValue << ";\n";
44874498
}
44884499
body << "return attr;\n";

mlir/tools/mlir-tblgen/OpFormatGen.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,16 +1999,18 @@ static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
19991999
fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
20002000
body << getter << "Attr() != "
20012001
<< tgfmt(attr.getConstBuilderTemplate(), &fctx,
2002-
attr.getDefaultValue());
2002+
tgfmt(attr.getDefaultValue(), &fctx));
20032003
}
20042004
if (optionalAndDefault)
20052005
body << ")";
20062006
}
20072007

20082008
static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
20092009
PropertyVariable &propElement) {
2010-
body << op.getGetterName(propElement.getVar()->name)
2011-
<< "() != " << propElement.getVar()->prop.getDefaultValue();
2010+
FmtContext fctx;
2011+
fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2012+
body << op.getGetterName(propElement.getVar()->name) << "() != "
2013+
<< tgfmt(propElement.getVar()->prop.getDefaultValue(), &fctx);
20122014
}
20132015

20142016
/// Elide the variadic segment size attributes if necessary.
@@ -2045,8 +2047,9 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
20452047
const StringRef &name = namedAttr.name;
20462048
FmtContext fctx;
20472049
fctx.withBuilder("odsBuilder");
2048-
std::string defaultValue = std::string(
2049-
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2050+
std::string defaultValue =
2051+
std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
2052+
tgfmt(attr.getDefaultValue(), &fctx)));
20502053
body << " {\n";
20512054
body << " ::mlir::Builder odsBuilder(getContext());\n";
20522055
body << " ::mlir::Attribute attr = " << op.getGetterName(name)
@@ -2059,8 +2062,10 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
20592062
// Similarly, elide default-valued properties.
20602063
for (const NamedProperty &prop : op.getProperties()) {
20612064
if (prop.prop.hasDefaultValue()) {
2065+
FmtContext fctx;
2066+
fctx.withBuilder("odsBuilder");
20622067
body << " if (" << op.getGetterName(prop.name)
2063-
<< "() == " << prop.prop.getDefaultValue() << ") {";
2068+
<< "() == " << tgfmt(prop.prop.getDefaultValue(), &fctx) << ") {";
20642069
body << " elidedProps.push_back(\"" << prop.name << "\");\n";
20652070
body << " }\n";
20662071
}
@@ -2094,8 +2099,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
20942099
const StringRef &name = namedAttr.name;
20952100
FmtContext fctx;
20962101
fctx.withBuilder("odsBuilder");
2097-
std::string defaultValue = std::string(
2098-
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2102+
std::string defaultValue =
2103+
std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
2104+
tgfmt(attr.getDefaultValue(), &fctx)));
20992105
body << " {\n";
21002106
body << " ::mlir::Builder odsBuilder(getContext());\n";
21012107
body << " ::mlir::Attribute attr = " << op.getGetterName(name)

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
879879
if (attr.hasDefaultValue()) {
880880
os << "if (!tblgen_attr) tblgen_attr = "
881881
<< std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
882-
attr.getDefaultValue()))
882+
tgfmt(attr.getDefaultValue(), &fmtCtx)))
883883
<< ";\n";
884884
} else if (attr.isOptional()) {
885885
// For a missing attribute that is optional according to definition, we

0 commit comments

Comments
 (0)