Disable compilation for training_step, validation_step, etc. #21569
Replies: 1 comment 4 replies
Ran into this myself when trying to optimize my workflow. It can be frustrating when everything gets compiled and you only want specific methods. The issue stems from how Fabric handles the model setup. You can compile just `forward` before handing the model to Fabric:

```python
# Create model
model = MyModel()

# Compile only the forward method
model.forward = torch.compile(model.forward)

# Setup with Fabric
model, optimizer = fabric.setup(model, optimizer)
```

This approach directly compiles only the `forward` method. Another possible way, in case your logic gets even more complex, is to separate concerns: isolate the computation-heavy parts into `forward` and keep the dynamic logic in `training_step`.

Keep in mind, though, this solution might not cover special Fabric internals if they change in future updates. It's always good to check out their documentation or raise an issue if new behaviors arise. Let me know if this approach changes anything for your use case!
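The separation-of-concerns suggestion can be sketched in plain PyTorch; the `MyModel` class below is a hypothetical stand-in, not code from Fabric or this thread:

```python
import torch

class MyModel(torch.nn.Module):
    """Hypothetical model: the compile-friendly math lives in forward."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # Computation-heavy, static-shape code only -- safe to compile.
        return torch.relu(self.linear(x))

    def training_step(self, batch):
        # Dynamic, hard-to-compile logic (metrics, logging) stays here;
        # this method itself is never passed to torch.compile.
        out = self.forward(batch)
        return out.mean()

model = MyModel()
# Compile just the hot path; training_step is left untouched.
model.forward = torch.compile(model.forward)
```

Since `model.forward` is replaced on the instance, `training_step`'s call to `self.forward` goes through the compiled wrapper while its own Python logic keeps running eagerly.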
Hi! I am training a model with Fabric and would like to compile only the `forward` method of the model and skip compilation for `training_step`, `validation_step`, etc. The reason for this is that `training_step` contains hard-to-compile logic like metrics tracking and has dynamic inputs; `forward`, however, is safe to compile.

It seems that the forward-wrapping magic in Fabric automatically results in `training_step` being compiled as well, and I haven't managed to exclude `training_step` from compilation. I tried wrapping it in `@torch.compiler.disable()`, but that didn't work. Calling `model.training_step = torch.compiler.disable(model.training_step, recursive=False)` doesn't work if called before `fabric.setup`, and results in a recursion-depth-exceeded error if called after `fabric.setup`. Any hints on how to achieve this, or what the best practice is for this scenario?
Code to reproduce:
Output:
Here the compile limit is reached because `filename` is different in every training step.

What does work is calling `torch.compile` only on the forward:
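One way to verify that only `forward` is traced is to pass a custom backend to `torch.compile` that counts graph compilations. This is a pure-PyTorch sketch outside of Fabric; `MyModel` is a hypothetical stand-in:

```python
import torch

compiled_graphs = []

def counting_backend(gm, example_inputs):
    # Record every FX graph Dynamo hands us, then run it unoptimized.
    compiled_graphs.append(gm)
    return gm.forward

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

    def training_step(self, batch):
        # Runs as eager Python: only the inner forward call hits the backend.
        return self.forward(batch).mean()

model = MyModel()
model.forward = torch.compile(model.forward, backend=counting_backend)

loss = model.training_step(torch.ones(2, 4))
# compiled_graphs now holds the graph for forward; training_step was not traced
```

If `training_step` had been swept into compilation, its metric/logging code would show up as extra graphs (or graph breaks) in `compiled_graphs`.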
But then `fabric.setup` doesn't re-apply compilation, because the whole model was not compiled. This is at least my understanding.