@@ -15,21 +15,19 @@ limitations under the License.
1515
1616// We need to keep some extra headers for the code in tpu_passes.h.inc.
1717
18- #include < cstdint>
1918#include < memory> // IWYU pragma: keep
2019#include < optional>
2120#include < string>
2221#include < string_view>
2322
24- #include " mlir/Dialect/Vector/IR/VectorOps.h"
2523#include " mlir/IR/BuiltinAttributes.h"
2624#include " mlir/IR/BuiltinOps.h"
25+ #include " mlir/Dialect/Vector/IR/VectorOps.h"
2726#include " mlir/IR/OperationSupport.h"
2827#include " mlir/IR/Value.h"
2928#include " mlir/IR/Visitors.h"
3029#include " mlir/Pass/Pass.h" // IWYU pragma: keep
3130#include " mlir/Support/LLVM.h"
32- #include " absl/strings/str_format.h"
3331#include " mlir/include/mlir/IR/BuiltinAttributes.h"
3432#include " mlir/include/mlir/IR/OpDefinition.h"
3533#include " mlir/include/mlir/IR/OperationSupport.h"
@@ -45,7 +43,7 @@ namespace {
4543
4644constexpr std::string_view kMangledDialect = " stable_mosaic." ;
4745constexpr StringRef kVersionAttrName = " stable_mosaic.version" ;
48- constexpr int kVersion = 4 ;
46+ constexpr int kVersion = 3 ;
4947
5048StringRef mangle (StringRef name, std::string* storage) {
5149 storage->clear ();
@@ -88,37 +86,21 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) {
8886
8987LogicalResult semaphore_signal_rule (Operation* op, int version) {
9088 // Added AttrSizedOperandSegments and core_id in version 2.
91- // Added subcore_id in version 4.
9289 if (version < 2 ) {
9390 if (op->getNumOperands () == 2 ) { // Local signal.
94- op->setAttr (
95- OpTrait::AttrSizedOperandSegments<
96- EnqueueDMAOp>::getOperandSegmentSizeAttr (),
97- mlir::DenseI32ArrayAttr::get (op->getContext (), {1 , 1 , 0 , 0 , 0 }));
91+ op->setAttr (OpTrait::AttrSizedOperandSegments<
92+ EnqueueDMAOp>::getOperandSegmentSizeAttr (),
93+ mlir::DenseI32ArrayAttr::get (op->getContext (), {1 , 1 , 0 , 0 }));
9894 } else if (op->getNumOperands () == 3 ) { // Remote signal.
99- op->setAttr (
100- OpTrait::AttrSizedOperandSegments<
101- EnqueueDMAOp>::getOperandSegmentSizeAttr (),
102- mlir::DenseI32ArrayAttr::get (op->getContext (), {1 , 1 , 1 , 0 , 0 }));
103- }
104- return op->emitError (" Unexpected operand count in tpu.semaphore_signal" );
105- } else if (version < 4 ) {
106- ArrayRef<int32_t > operand_segment_sizes =
107- op->getAttrOfType <DenseI32ArrayAttr>(
108- OpTrait::AttrSizedOperandSegments<
109- SemaphoreSignalOp>::getOperandSegmentSizeAttr ());
110- if (operand_segment_sizes.size () != 4 ) {
111- return op->emitError (absl::StrFormat (
112- " Expected operand count to be 4 in tpu.semaphore_signal. Got %d" ,
113- operand_segment_sizes.size ()));
95+ // Hardcoding that one optional value is device_id, not core_id. This
96+ // could misinterpret sem_signals where core_id is specified, but
97+ // device_id isn't.
98+ op->setAttr (OpTrait::AttrSizedOperandSegments<
99+ EnqueueDMAOp>::getOperandSegmentSizeAttr (),
100+ mlir::DenseI32ArrayAttr::get (op->getContext (), {1 , 1 , 1 , 0 }));
101+ } else {
102+ return op->emitError (" Unexpected operand count in tpu.semaphore_signal" );
114103 }
115- SmallVector<int32_t , 5 > new_operand_segment_sizes (
116- operand_segment_sizes.begin (), operand_segment_sizes.end ());
117- new_operand_segment_sizes.push_back (0 );
118- op->setAttr (OpTrait::AttrSizedOperandSegments<
119- EnqueueDMAOp>::getOperandSegmentSizeAttr (),
120- mlir::DenseI32ArrayAttr::get (op->getContext (),
121- new_operand_segment_sizes));
122104 }
123105 return success ();
124106}
0 commit comments