Skip to content

Commit 869a533

Browse files
bythew3iGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Add bound check for general vector store op.
PiperOrigin-RevId: 698577015
1 parent 840cf3f commit 869a533

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,14 @@ void tpu_strided_store_rule(tpu::StridedStoreOp op) {
122122
/*strides=*/op.getStrides());
123123
}
124124

125+
void tpu_vector_store_rule(tpu::VectorStoreOp op) {
126+
// TODO(b/379925823): Take strides into account.
127+
assertIsValidSubwindow(
128+
op, op.getIndices(),
129+
/*window_shape=*/op.getValueToStore().getType().getShape(),
130+
/*full_shape=*/op.getBase().getType().getShape());
131+
}
132+
125133
const llvm::StringMap<rule_type> &rules() {
126134
static auto rules = new llvm::StringMap<rule_type>{
127135
// TODO: tpu::LoadOp, tpu::StoreOp
@@ -133,6 +141,8 @@ const llvm::StringMap<rule_type> &rules() {
133141
as_generic_rule(tpu_strided_load_rule)},
134142
{tpu::StridedStoreOp::getOperationName(),
135143
as_generic_rule(tpu_strided_store_rule)},
144+
{tpu::VectorStoreOp::getOperationName(),
145+
as_generic_rule(tpu_vector_store_rule)},
136146
};
137147
return *rules;
138148
}

0 commit comments

Comments
 (0)