@@ -56,6 +56,7 @@ using namespace mlir;
5656MLIRBench::MLIRBench (mlir::Operation *op, const MLIRBenchConfig &config)
5757 : builder(op->getContext ()), unkLoc(builder.getUnknownLoc()) {
5858 seed = config.seed ;
59+ identity = config.identity ;
5960 backend = config.backend ;
6061 initType = config.initType ;
6162 offloadToDevice = config.offloadToDevice ;
@@ -113,7 +114,7 @@ LogicalResult MLIRBench::replaceSplatWithRandom() {
113114 return module .emitError (" No seed for random init" );
114115
115116 // Only replace attribute if it's a dense splat
116- auto replaceSplat = [&](ShapedType shape, Attribute attr) -> Attribute {
117+ auto replaceSplat = [&](ShapedType shape, Attribute attr) -> FailureOr< Attribute> {
117118 // We only change dense attributes that are splat
118119 auto value = dyn_cast<DenseElementsAttr>(attr);
119120 if (!value || !value.isSplat ())
@@ -145,7 +146,9 @@ LogicalResult MLIRBench::replaceSplatWithRandom() {
145146 if (!global)
146147 continue ;
147148 auto newAttr = replaceSplat (global.getType (), global.getInitialValueAttr ());
148- global.setInitialValueAttr (newAttr);
149+ if (failed (newAttr))
150+ return failure ();
151+ global.setInitialValueAttr (newAttr.value ());
149152 }
150153
151154 // Tensors are arith.constant values
@@ -157,7 +160,9 @@ LogicalResult MLIRBench::replaceSplatWithRandom() {
157160 if (!cstType)
158161 continue ;
159162 auto newAttr = replaceSplat (cstType, constant.getValueAttr ());
160- constant.setValueAttr (cast<TypedAttr>(newAttr));
163+ if (failed (newAttr))
164+ return failure ();
165+ constant.setValueAttr (cast<TypedAttr>(newAttr.value ()));
161166 }
162167
163168 return success ();
@@ -212,34 +217,48 @@ LogicalResult MLIRBench::createKernelArgs() {
212217 auto &mainBody = getMainBlock ();
213218 builder.setInsertionPointToStart (&mainBody);
214219
220+ int argNum = 0 ;
215221 for (auto &ty : kernel.getArgumentTypes ()) {
216- auto arg = TypeSwitch<Type, std::optional<Value>>(ty)
217- .Case <MemRefType>([&](auto memRefTy) {
218- // Create a memref global
219- Value data = createDenseMemref (builder, module , initType,
220- memRefTy, seed);
221- data = registerOnGpu (data, memRefTy);
222- return data;
223- })
224- .Case <TensorType>([&](auto tensorTy) {
225- // Create a memref global and cast it to a tensor
226- // to ensure that the buffer is writable and
227- // bufferization does not insert extra
228- // allocations + copies
229- auto memrefType = MemRefType::get (
230- tensorTy.getShape (), tensorTy.getElementType ());
231- auto data = createDenseMemref (builder, module , initType,
232- memrefType, seed);
233- data = registerOnGpu (data, memrefType);
234- return builder.create <bufferization::ToTensorOp>(
235- unkLoc, tensorTy, data, /* restrict=*/ true , /* writable=*/ true );
236- })
237- .Default ([&](auto t) { return std::nullopt ; });
222+ auto argInitType = initType;
223+ // Requested an argument to be identity, must be 2D square
224+ if (argNum == identity) {
225+ ShapedType shape = dyn_cast<ShapedType>(ty);
226+ if (shape && shape.getRank () == 2 &&
227+ shape.getDimSize (0 ) == shape.getDimSize (1 )) {
228+ argInitType = TensorInitType::Identity;
229+ } else {
230+ return module .emitError (" Invalid shape for identity init" );
231+ }
232+ }
233+ auto arg =
234+ TypeSwitch<Type, std::optional<Value>>(ty)
235+ .Case <MemRefType>([&](auto memRefTy) {
236+ // Create a memref global
237+ Value data = createDenseMemref (builder, module , argInitType,
238+ memRefTy, seed);
239+ data = registerOnGpu (data, memRefTy);
240+ return data;
241+ })
242+ .Case <TensorType>([&](auto tensorTy) {
243+ // Create a memref global and cast it to a tensor
244+ // to ensure that the buffer is writable and
245+ // bufferization does not insert extra
246+ // allocations + copies
247+ auto memrefType = MemRefType::get (tensorTy.getShape (),
248+ tensorTy.getElementType ());
249+ auto data = createDenseMemref (builder, module , argInitType,
250+ memrefType, seed);
251+ data = registerOnGpu (data, memrefType);
252+ return builder.create <bufferization::ToTensorOp>(
253+ unkLoc, tensorTy, data, /* restrict=*/ true , /* writable=*/ true );
254+ })
255+ .Default ([&](auto t) { return std::nullopt ; });
238256
239257 if (!arg)
240- return failure ( );
258+ return module . emitError ( " Cannot create kernel argument " );
241259
242260 kernelArgs.push_back (*arg);
261+ argNum++;
243262 }
244263
245264 return success ();
0 commit comments