Commit 5053f55

[Init] Add identity initializer (#1079)
This is a step towards allowing tests like A * I = A. As is, it still doesn't do what we want: like the other initializers, it tries to initialize all tensors with the identity, so any non-square tensor will fail. The next step is to add options like "--identity-b" together with "--splat-to-random", so that only the B matrix is replaced by the identity. This change also removes some unused initializers.
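
For context, the property this initializer is meant to enable in tests is the usual matrix identity A * I = A. A minimal, self-contained C++ sketch of that check, independent of the TPP runner and using hypothetical 2x2 data:

#include <array>
#include <cassert>

int main() {
  // Hypothetical 2x2 inputs: multiplying A by the identity must return A.
  std::array<std::array<float, 2>, 2> A = {{{1.0f, 2.0f}, {3.0f, 4.0f}}};
  std::array<std::array<float, 2>, 2> I = {{{1.0f, 0.0f}, {0.0f, 1.0f}}};
  std::array<std::array<float, 2>, 2> C = {};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k)
        C[i][j] += A[i][k] * I[k][j];
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      assert(C[i][j] == A[i][j]); // A * I == A
  return 0;
}
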
1 parent c859651 commit 5053f55

File tree

18 files changed: +246, -196 lines changed


include/TPP/Passes.td

Lines changed: 3 additions & 0 deletions
@@ -577,6 +577,9 @@ def TppRunnerWrapper : Pass<"tpp-runner-wrapper", "ModuleOp">{
     Option<"initType", "init-type", "std::string",
            /*default=*/"",
            "Initializer type (const, simple, cont, rand, normal).">,
+    Option<"identity", "identity", "int64_t",
+           /*default=*/"-1",
+           "Identity square argument (-1=none, 0=a, 1=b, ...).">,
   ];
 }
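
The new option takes the zero-based index of the kernel argument to fill with an identity matrix; -1 (the default) disables it. Assuming the standard MLIR pass-option syntax, it would be set as part of the pipeline string, e.g. tpp-runner-wrapper{identity=1 init-type=rand}; how (or whether) tpp-run surfaces this as a dedicated flag is not part of this diff.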

include/TPP/Runner/MLIRBench.h

Lines changed: 6 additions & 2 deletions
@@ -40,13 +40,14 @@ class FuncOp;
 // pipeline.
 struct MLIRBenchConfig {
   MLIRBenchConfig() = default;
-  MLIRBenchConfig(int seed, TensorInitType initType, std::string backend,
+  MLIRBenchConfig(int seed, TensorInitType initType, int identity, std::string backend,
                   bool offloadToDevice)
-      : seed(seed), initType(initType), backend(backend),
+      : seed(seed), initType(initType), identity(identity), backend(backend),
        offloadToDevice(offloadToDevice) {}

   int seed = 0;
   TensorInitType initType = TensorInitType::Auto;
+  int identity = -1;
   std::string backend = "cpu";
   bool offloadToDevice = true;
 };
@@ -85,6 +86,9 @@ class MLIRBench {
   /// Seed for the random tensor filling
   int seed;

+  /// Which argument is the identity, if any
+  int identity;
+
   /// Tensor init type
   TensorInitType initType;
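
A minimal sketch of how a caller could populate the extended config, based only on the constructor shown above (seed, init type, and backend values are placeholders, and enclosing namespaces are omitted):

#include "TPP/Runner/MLIRBench.h"

static MLIRBenchConfig makeConfig() {
  // identity == 1 requests the second kernel argument to be an identity
  // matrix; -1 (the default) keeps the regular initializer for all arguments.
  return MLIRBenchConfig(/*seed=*/42, TensorInitType::Normal,
                         /*identity=*/1, /*backend=*/"cpu",
                         /*offloadToDevice=*/false);
}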

include/TPP/Transforms/Utils/TensorInit.h

Lines changed: 24 additions & 7 deletions
@@ -28,7 +28,8 @@ struct ITensorInit {
   // Returns a dense attribute with a specified shape, initialized
   // with a particular implementation (see derived classes) with
   // a reasonable distribution.
-  virtual mlir::DenseElementsAttr get(mlir::ShapedType shape) = 0;
+  virtual llvm::FailureOr<mlir::DenseElementsAttr>
+  get(mlir::ShapedType shape) = 0;
 };

 // Base class.
@@ -39,12 +40,15 @@ template <typename T> struct TensorInit : public ITensorInit {
   // Returns a dense attribute with a specified shape, initialized
   // with a particular implementation (see derived classes) with
   // a reasonable distribution.
-  virtual mlir::DenseElementsAttr get(mlir::ShapedType shape) override {
+  virtual llvm::FailureOr<mlir::DenseElementsAttr>
+  get(mlir::ShapedType shape) override {
+    if (!checkShape(shape))
+      return llvm::failure();
+
+    // Populate the shape
     buffer.clear();
-    size = 1;
-    for (size_t dim = 0, rank = shape.getRank(); dim < rank; dim++)
-      size *= shape.getDimSize(dim);
     fillData();
+
     // For some reason, memref global op needs dense tensor type
     // See: lib/Dialect/MemRef/IR/MemRefOps.cpp :: GlobalOp::verify
     auto tensorType =
@@ -53,11 +57,25 @@ template <typename T> struct TensorInit : public ITensorInit {
   }

 protected:
+  // Shape dims
+  std::vector<size_t> dims;
   // Number of elements in the shape
   size_t size;
   // Data pointer
   std::vector<T> buffer;

+  // Check the shape and fill the internal structure
+  virtual bool checkShape(mlir::ShapedType shape) {
+    size = 1;
+    dims.clear();
+    for (size_t i = 0, rank = shape.getRank(); i < rank; i++) {
+      auto dim = shape.getDimSize(i);
+      dims.push_back(dim);
+      size *= dim;
+    }
+    return true;
+  }
+
   // Insert element indexed on the buffer
   virtual void insert(size_t index, T value) {
     buffer[index] = value;
@@ -82,10 +100,9 @@ template <typename T> struct TensorInit : public ITensorInit {
 enum class TensorInitType {
   Auto,
   Constant,
-  Simple,
-  Continuous,
   Random,
   Normal,
+  Identity,
   Invalid
 };
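
Since get() now returns llvm::FailureOr, callers have to check for failure before using the attribute, which is what lets a non-square identity request propagate as an error instead of producing a bad initializer. A minimal sketch of the consuming pattern (mirroring BuilderUtils.cpp further down; use() is a hypothetical consumer):

auto init = getTensorInit(initType, type.getElementType(), seed);
auto denseInit = init->get(type); // fails, e.g., for a non-square identity shape
if (failed(denseInit))
  return failure(); // or assert/emitError, depending on the caller
use(denseInit.value()); // hypothetical consumer of the DenseElementsAttr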

include/TPP/Transforms/Utils/TensorInitFloat.h

Lines changed: 23 additions & 19 deletions
@@ -85,30 +85,14 @@ struct TensorInitFloat : public TensorInit<llvm::APFloat> {
   virtual void fillData() override = 0;
 };

-// Constant init (all-ones, do not use!).
+// Constant init (all-ones).
 struct ConstantTensorInitFloat : TensorInitFloat {
   ConstantTensorInitFloat(DataType type) : TensorInitFloat(type) {}

   // Return a dense<1.0> repeated throughout the shape.
-  mlir::DenseElementsAttr get(mlir::ShapedType shape) override;
+  mlir::FailureOr<mlir::DenseElementsAttr> get(mlir::ShapedType shape) override;

-  void fillData() override;
-};
-
-// Simple init (basic example, not useful).
-struct SimpleTensorInitFloat : TensorInitFloat {
-  SimpleTensorInitFloat(DataType type) : TensorInitFloat(type) {}
-
-  // Return a dense<0.3, 0.6, 0.9> repeated throughout the shape.
-  void fillData() override;
-};
-
-// Continuous init (normalized affine range).
-struct ContinuousTensorInitFloat : TensorInitFloat {
-  ContinuousTensorInitFloat(DataType type) : TensorInitFloat(type) {}
-
-  // Return a dense<0.0 ... 1.0> throughout the shape.
-  void fillData() override;
+  void fillData() override { assert(false && "Should not be called"); }
 };

 // Random init (uniform).
@@ -151,4 +135,24 @@ struct NormalTensorInitFloat : TensorInitFloat {
   std::normal_distribution<float> distribution;
 };

+// Identity init.
+struct IdentityTensorInitFloat : TensorInitFloat {
+  IdentityTensorInitFloat(DataType type)
+      : TensorInitFloat(type) {}
+
+  // Makes sure the shape is "square"
+  bool checkShape(mlir::ShapedType shape) override {
+    if (!TensorInit::checkShape(shape))
+      return false;
+    // Now the fields are set, compare all dims to be equal, 2D only for now
+    return dims.size() == 2 && dims[0] == dims[1];
+  }
+
+  // Should not be called.
+  float next() { assert(false && "Should not be called"); }
+
+  // Return a diagonal of <1.0>s throughout the shape.
+  void fillData() override;
+};
+
 #endif // TPP_TRANSFORMS_UTILS_TENSORINITFLOAT_H
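
IdentityTensorInitFloat::fillData() is only declared here; its definition lives in the corresponding .cpp file, which is not shown in this excerpt. A plausible sketch of the diagonal fill, assuming a row-major buffer and the dims/size fields from TensorInit, and ignoring any FP16/BF16 conversion the other float initializers apply:

// Sketch only, not the committed implementation.
void IdentityTensorInitFloat::fillData() {
  // checkShape() has already enforced a 2D square shape (dims[0] == dims[1]).
  buffer.assign(size, llvm::APFloat(0.0f));
  for (size_t i = 0; i < dims[0]; i++)
    insert(i * dims[1] + i, llvm::APFloat(1.0f)); // 1.0 on the diagonal
}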

include/TPP/Transforms/Utils/TensorInitInt.h

Lines changed: 23 additions & 26 deletions
@@ -74,37 +74,14 @@ struct TensorInitInt : public TensorInit<llvm::APInt> {
   virtual void fillData() override = 0;
 };

-// Constant init (all-ones, do not use!).
+// Constant init (all-ones).
 struct ConstantTensorInitInt : TensorInitInt {
   ConstantTensorInitInt(DataType type) : TensorInitInt(type) {}

   // Return a dense<1> repeated throughout the shape.
-  mlir::DenseElementsAttr get(mlir::ShapedType shape) override;
+  mlir::FailureOr<mlir::DenseElementsAttr> get(mlir::ShapedType shape) override;

-  void fillData() override;
-};
-
-// Simple init (basic example, not useful).
-struct SimpleTensorInitInt : TensorInitInt {
-  SimpleTensorInitInt(DataType type) : TensorInitInt(type) {}
-
-  // Return a dense<0, 1, 2> repeated throughout the shape.
-  void fillData() override;
-};
-
-// Continuous init (quantized normalized affine range).
-struct ContinuousTensorInitInt : TensorInitInt {
-  ContinuousTensorInitInt(DataType type)
-      : TensorInitInt(type), upperBound(255) {
-    if (type == DataType::I8)
-      upperBound = 127;
-  }
-
-  // Return a dense<0 ... upperBound> throughout the shape.
-  void fillData() override;
-
-  // Upper bound for quantization.
-  int upperBound;
+  void fillData() override { assert(false && "Should not be called"); }
 };

 // Random init (uniform).
@@ -152,4 +129,24 @@ struct NormalTensorInitInt : TensorInitInt {
   std::binomial_distribution<uint64_t> distribution;
 };

+// Identity init.
+struct IdentityTensorInitInt : TensorInitInt {
+  IdentityTensorInitInt(DataType type)
+      : TensorInitInt(type) {}
+
+  // Makes sure the shape is "square"
+  bool checkShape(mlir::ShapedType shape) override {
+    if (!TensorInit::checkShape(shape))
+      return false;
+    // Now the fields are set, compare all dims to be equal, 2D only for now
+    return dims.size() == 2 && dims[0] == dims[1];
+  }
+
+  // Should not be called.
+  float next() { assert(false && "Should not be called"); }
+
+  // Return a diagonal of <1.0>s throughout the shape.
+  void fillData() override;
+};
+
 #endif // TPP_TRANSFORMS_UTILS_TENSORINITINT_H

lib/TPP/Runner/MLIRBench.cpp

Lines changed: 45 additions & 26 deletions
@@ -56,6 +56,7 @@ using namespace mlir;
 MLIRBench::MLIRBench(mlir::Operation *op, const MLIRBenchConfig &config)
     : builder(op->getContext()), unkLoc(builder.getUnknownLoc()) {
   seed = config.seed;
+  identity = config.identity;
   backend = config.backend;
   initType = config.initType;
   offloadToDevice = config.offloadToDevice;
@@ -113,7 +114,7 @@ LogicalResult MLIRBench::replaceSplatWithRandom() {
     return module.emitError("No seed for random init");

   // Only replace attribute if it's a dense splat
-  auto replaceSplat = [&](ShapedType shape, Attribute attr) -> Attribute {
+  auto replaceSplat = [&](ShapedType shape, Attribute attr) -> FailureOr<Attribute> {
     // We only change dense attributes that are splat
     auto value = dyn_cast<DenseElementsAttr>(attr);
     if (!value || !value.isSplat())
@@ -145,7 +146,9 @@ LogicalResult MLIRBench::replaceSplatWithRandom() {
     if (!global)
      continue;
    auto newAttr = replaceSplat(global.getType(), global.getInitialValueAttr());
-    global.setInitialValueAttr(newAttr);
+    if (failed(newAttr))
+      return failure();
+    global.setInitialValueAttr(newAttr.value());
   }

   // Tensors are arith.constant values
@@ -157,7 +160,9 @@ LogicalResult MLIRBench::replaceSplatWithRandom() {
     if (!cstType)
      continue;
    auto newAttr = replaceSplat(cstType, constant.getValueAttr());
-    constant.setValueAttr(cast<TypedAttr>(newAttr));
+    if (failed(newAttr))
+      return failure();
+    constant.setValueAttr(cast<TypedAttr>(newAttr.value()));
   }

   return success();
@@ -212,34 +217,48 @@ LogicalResult MLIRBench::createKernelArgs() {
   auto &mainBody = getMainBlock();
   builder.setInsertionPointToStart(&mainBody);

+  int argNum = 0;
   for (auto &ty : kernel.getArgumentTypes()) {
-    auto arg = TypeSwitch<Type, std::optional<Value>>(ty)
-                   .Case<MemRefType>([&](auto memRefTy) {
-                     // Create a memref global
-                     Value data = createDenseMemref(builder, module, initType,
-                                                    memRefTy, seed);
-                     data = registerOnGpu(data, memRefTy);
-                     return data;
-                   })
-                   .Case<TensorType>([&](auto tensorTy) {
-                     // Create a memref global and cast it to a tensor
-                     // to ensure that the buffer is writable and
-                     // bufferization does not insert extra
-                     // allocations + copies
-                     auto memrefType = MemRefType::get(
-                         tensorTy.getShape(), tensorTy.getElementType());
-                     auto data = createDenseMemref(builder, module, initType,
-                                                   memrefType, seed);
-                     data = registerOnGpu(data, memrefType);
-                     return builder.create<bufferization::ToTensorOp>(
-                         unkLoc, tensorTy, data, /*restrict=*/true, /*writable=*/true);
-                   })
-                   .Default([&](auto t) { return std::nullopt; });
+    auto argInitType = initType;
+    // Requested an argument to be identity, must be 2D square
+    if (argNum == identity) {
+      ShapedType shape = dyn_cast<ShapedType>(ty);
+      if (shape && shape.getRank() == 2 &&
+          shape.getDimSize(0) == shape.getDimSize(1)) {
+        argInitType = TensorInitType::Identity;
+      } else {
+        return module.emitError("Invalid shape for identity init");
+      }
+    }
+    auto arg =
+        TypeSwitch<Type, std::optional<Value>>(ty)
+            .Case<MemRefType>([&](auto memRefTy) {
+              // Create a memref global
+              Value data = createDenseMemref(builder, module, argInitType,
+                                             memRefTy, seed);
+              data = registerOnGpu(data, memRefTy);
+              return data;
+            })
+            .Case<TensorType>([&](auto tensorTy) {
+              // Create a memref global and cast it to a tensor
+              // to ensure that the buffer is writable and
+              // bufferization does not insert extra
+              // allocations + copies
+              auto memrefType = MemRefType::get(tensorTy.getShape(),
+                                                tensorTy.getElementType());
+              auto data = createDenseMemref(builder, module, argInitType,
+                                            memrefType, seed);
+              data = registerOnGpu(data, memrefType);
+              return builder.create<bufferization::ToTensorOp>(
+                  unkLoc, tensorTy, data, /*restrict=*/true, /*writable=*/true);
+            })
+            .Default([&](auto t) { return std::nullopt; });

     if (!arg)
-      return failure();
+      return module.emitError("Cannot create kernel argument");

     kernelArgs.push_back(*arg);
+    argNum++;
   }

   return success();
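
In effect, for a GEMM-like kernel taking (A, B, C), setting identity to 1 builds only B with TensorInitType::Identity while A and C keep the configured init type; if the selected argument is not a 2D square shape, argument creation now stops with "Invalid shape for identity init" rather than emitting a broken initializer.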

lib/TPP/Runner/TppRunnerWrapper.cpp

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ struct TppRunnerWrapper
   }

   // Benchmark object.
-  MLIRBenchConfig config(seed, tensorInitType, backend, offloadToDevice);
+  MLIRBenchConfig config(seed, tensorInitType, identity, backend, offloadToDevice);
   MLIRBench bench(module, config);

   // Can only either print or run benchmarks, make this clear before we try to

lib/TPP/Transforms/Utils/BuilderUtils.cpp

Lines changed: 4 additions & 2 deletions
@@ -79,7 +79,8 @@ Value createDenseTensor(OpBuilder &builder, TensorInitType initType,
   auto unkLoc = builder.getUnknownLoc();
   auto init = getTensorInit(initType, type.getElementType(), seed);
   auto floatInit = init->get(type);
-  return builder.create<arith::ConstantOp>(unkLoc, type, floatInit);
+  assert(!failed(floatInit) && "Invalid dense tensor initializer");
+  return builder.create<arith::ConstantOp>(unkLoc, type, floatInit.value());
 }

 Value createDenseMemref(OpBuilder &builder, ModuleOp module,
@@ -103,10 +104,11 @@ Value createDenseMemref(OpBuilder &builder, ModuleOp module,
     auto alignment = builder.getIntegerAttr(builder.getI64Type(), 128);
     auto init = getTensorInit(initType, type.getElementType(), seed);
     auto floatInit = init->get(type);
+    assert(!failed(floatInit) && "Invalid dense tensor initializer");

     // Create the global object in the Module's region
     auto global = builder.create<memref::GlobalOp>(
-        unkLoc, StringRef(name), privAttr, type, floatInit,
+        unkLoc, StringRef(name), privAttr, type, floatInit.value(),
         /*constant=*/false, alignment);
     globalName = global.getName();
   }