
Commit 5e8cbb7

Update constraints to torch 2.4 dynamic_shapes API (#806)

This commit switches dynamic-dimension handling to the dynamic_shapes API, since the constraints argument is deprecated in torch 2.4. (The SHARK test checks out the main branch of this repo and will pass once this is merged.)

1 parent 26ce08e commit 5e8cbb7
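For context, a minimal sketch of the new-style export call outside this repo's CompiledModule machinery; the Toy module, tensor sizes, and the "batch" dim name are illustrative, not taken from the commit:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Pre-2.4 style (deprecated): pass a list of `constraints` built from dynamic_dim().
# 2.4 style: declare dynamic dimensions up front with torch.export.Dim.
batch = torch.export.Dim("batch", max=15)  # upper bound on the dynamic batch size

exported = torch.export.export(
    Toy(),
    (torch.randn(4, 3),),
    dynamic_shapes={"x": {0: batch}},  # keyed by argument name; only dim 0 is dynamic
)
print(exported)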

File tree

2 files changed: +24 -21 lines changed


models/turbine_models/custom_models/resnet_18.py

Lines changed: 2 additions & 2 deletions

@@ -56,8 +56,8 @@ class CompiledResnet18Model(CompiledModule):
         params = export_parameters(resnet_model.model)
 
         def main(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
-            const = [x.dynamic_dim(0) < 16]
-            return jittable(resnet_model.forward)(x, constraints=const)
+            dynamic_shapes = {"arg0_1": {0: torch.export.Dim("dim", max=15)}}
+            return jittable(resnet_model.forward)(x, dynamic_shapes=dynamic_shapes)
 
     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
     inst = CompiledResnet18Model(context=Context(), import_to=import_to)
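Two details of this hunk are easy to miss: torch.export.Dim bounds are inclusive, so the old strict constraint `dynamic_dim(0) < 16` becomes `max=15`, and the spec is keyed by the placeholder name in the traced graph ("arg0_1") rather than the Python parameter name. A hedged restatement of the translation:

import torch

# Old (pre-2.4):  const = [x.dynamic_dim(0) < 16]   -> strict upper bound
# New (2.4):      Dim bounds are inclusive, so `< 16` is expressed as max=15.
batch_dim = torch.export.Dim("dim", max=15)

# Keyed by the traced placeholder name ("arg0_1") and the dynamic axis index (0).
dynamic_shapes = {"arg0_1": {0: batch_dim}}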

models/turbine_models/custom_models/stateless_llama.py

Lines changed: 22 additions & 19 deletions

@@ -237,8 +237,10 @@ class StateUpdateModule(CompiledModule):
         def run_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
-            init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
-            token, *state = self.initialize(x, constraints=init_const)
+            dynamic_shapes_init = {
+                "arg0_1": {1: torch.export.Dim("dim", max=MAX_STEP_SEQ - 1)}
+            }
+            token, *state = self.initialize(x, dynamic_shapes=dynamic_shapes_init)
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
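The same inclusive-bound translation applies here (`< MAX_STEP_SEQ` becomes `max=MAX_STEP_SEQ - 1`), and because only axis 1 appears in the spec, the batch axis stays static. A self-contained sketch with an illustrative value for MAX_STEP_SEQ:

import torch

MAX_STEP_SEQ = 4095  # illustrative; the real constant is defined in stateless_llama.py

# Only axis 1 (the sequence axis) is declared dynamic; unlisted axes stay static.
dynamic_shapes_init = {
    "arg0_1": {1: torch.export.Dim("dim", max=MAX_STEP_SEQ - 1)}
}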
@@ -267,16 +269,15 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 HIDDEN_DIM,
                 NUM_LAYERS,
             )
-            forw_const = (
-                [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
-                + [
-                    x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
-                    for x in state_arg[1:]
-                ]
-                + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
+            state_arg0_dim = torch.export.Dim(
+                "state_arg0_dim", max=MAX_STEP_SEQ - 1
             )
+            dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: state_arg0_dim}}
+            for state_arg_idx in range(2, len(state_arg) + 1):
+                current_dim_dict = {f"arg{state_arg_idx}_1": {1: state_arg0_dim}}
+                dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
             token, *state_update = self.forward(
-                x, *state_arg, constraints=forw_const
+                x, *state_arg, dynamic_shapes=dynamic_shapes_forw
             )
             for i in range(NUM_LAYERS):
                 update = IREE.tensor_reshape(
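The single shared state_arg0_dim object is what replaces the old equality constraints (`x.dynamic_dim(1) == state_arg[0].dynamic_dim(1)`): every input keyed to the same Dim must have the same runtime size, while `"arg0_1": None` marks the one-token input as fully static. A hedged demonstration with a toy module and made-up sizes, not the repo's actual model:

import torch

class KVToy(torch.nn.Module):
    def forward(self, x, a, b):
        # a and b share a dynamic sequence axis; x is a static one-token input
        return a + b + x.sum()

shared = torch.export.Dim("state_arg0_dim", max=4094)  # illustrative bound

exported = torch.export.export(
    KVToy(),
    (torch.randn(1, 1), torch.randn(1, 8), torch.randn(1, 8)),
    dynamic_shapes={
        "x": None,         # fully static, like "arg0_1" in the diff
        "a": {1: shared},  # the same Dim object on both kv-style inputs
        "b": {1: shared},  # encodes the old `==` constraint
    },
)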
@@ -343,17 +344,19 @@ def run_cached_initialize(
                 HIDDEN_DIM,
                 NUM_LAYERS,
             )
-            forw_const = (
-                [x.dynamic_dim(1) < MAX_STEP_SEQ]
-                + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
-                + [
-                    x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
-                    for x in state_arg[1:]
-                ]
-                + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
+            state_arg0_dim1 = torch.export.Dim(
+                "state_arg0_dim1", max=MAX_STEP_SEQ - 1
             )
+            x_dim = torch.export.Dim("x_dim", max=MAX_STEP_SEQ - 1)
+            dynamic_shapes_forw = {
+                "arg0_1": {1: x_dim},
+                "arg1_1": {1: state_arg0_dim1},
+            }
+            for state_arg_idx in range(2, len(state_arg) + 1):
+                current_dim_dict = {f"arg{state_arg_idx}_1": {1: state_arg0_dim1}}
+                dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
             token, *state = self.cached_initialize(
-                x, *state_arg, constraints=forw_const
+                x, *state_arg, dynamic_shapes=dynamic_shapes_forw
             )
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
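The dict-merging loop that appears in both hunks adds one spec per remaining kv-cache tensor, all pointing at the same Dim; functionally it could also be written as a single comprehension. A runnable restatement with stand-in values (the real state_arg and MAX_STEP_SEQ come from the surrounding module):

import torch

MAX_STEP_SEQ = 4095      # stand-in value
state_arg = [None] * 48  # stands in for the kv-cache tensor list; only len() matters here

x_dim = torch.export.Dim("x_dim", max=MAX_STEP_SEQ - 1)
state_arg0_dim1 = torch.export.Dim("state_arg0_dim1", max=MAX_STEP_SEQ - 1)

# Equivalent to the loop in the diff: one entry per kv tensor, all sharing one Dim.
dynamic_shapes_forw = {
    "arg0_1": {1: x_dim},
    "arg1_1": {1: state_arg0_dim1},
    **{f"arg{i}_1": {1: state_arg0_dim1} for i in range(2, len(state_arg) + 1)},
}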
