Skip to content

Commit 9dbeb83

Browse files
authored
[water] enforce uniqueness of symbolic tensor dimensions (#776)
Symbols indicate how dimensions are indexed. Repeating the symbol means we would be indexing two dimensions identically, i.e., their indexes are always equal, which is not desirable.
1 parent 27a54c9 commit 9dbeb83

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

water/include/water/Dialect/Wave/IR/WaveTypes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ def WaveTensorType : TypeDef<WaveDialect, "WaveTensor"> {
2323
unknown in the earlier stages and will be inferred later. Tensors may have
2424
an address space indicating whether the data is expected to live in a
2525
certain location, including local data store (shared memory) and registers.
26+
27+
When the shape is specified, symbols used in its dimensions indicate which
28+
"logical iterators" will be used to index into the tensor. Therefore, all
29+
symbols must be unique. (If the same symbol were used twice, it would mean
30+
that only a diagonal of the tensor can be indexed since two or more
31+
dimensions would be co-indexed).
2632
}];
2733

2834
let parameters = (ins

water/lib/Dialect/Wave/IR/WaveTypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ wave::WaveTensorType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
8181
return emitError() << "expected element type to be integer, index or "
8282
"floating point scalar";
8383
}
84+
llvm::SmallPtrSet<Attribute, 6> seenSymbols;
85+
for (auto symbol : shape) {
86+
if (!seenSymbols.insert(symbol).second) {
87+
return emitError() << "duplicate symbol " << symbol << " in shape";
88+
}
89+
}
8490
return success();
8591
}
8692

water/test/Dialect/Wave/attr-type-invalid.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,8 @@ module attributes {wave_test.symbol = #wave.symbol<"_A">}
127127
step = affine_map<()[s0, s1] -> (s0)>,
128128
stride = affine_map<()[s0, s1] -> (s0)>
129129
} : () -> ()
130+
131+
// -----
132+
133+
// expected-error @below {{duplicate symbol #wave.symbol<"A"> in shape}}
134+
"wave_test.create_tensor"() {fully_specified = true, shape = [@A, @B, @A]} : () -> ()

0 commit comments

Comments
 (0)