|
29 | 29 | from modules.ui.TrainingTab import TrainingTab |
30 | 30 | from modules.ui.VideoToolUI import VideoToolUI |
31 | 31 | from modules.util import create |
| 32 | +from modules.util.attn.flash_attn_win import disable_flash_attn_win, enable_flash_attn_win |
32 | 33 | from modules.util.callbacks.TrainCallbacks import TrainCallbacks |
33 | 34 | from modules.util.commands.TrainCommands import TrainCommands |
34 | 35 | from modules.util.config.TrainConfig import TrainConfig |
@@ -133,6 +134,7 @@ def __init__(self): |
133 | 134 | self.always_on_tensorboard_subprocess = None |
134 | 135 | self.current_workspace_dir = self.train_config.workspace_dir |
135 | 136 | self._check_start_always_on_tensorboard() |
| 137 | + self._flash_attn_fallback_toggle() |
136 | 138 |
|
137 | 139 | self.workspace_dir_trace_id = self.ui_state.add_var_trace("workspace_dir", self._on_workspace_dir_change_trace) |
138 | 140 |
|
@@ -335,6 +337,10 @@ def create_general_tab(self, master): |
335 | 337 | tooltip="The device used to temporarily offload models while they are not used. Default:\"cpu\"") |
336 | 338 | components.entry(frame, 16, 1, self.ui_state, "temp_device") |
337 | 339 |
|
| 340 | + components.label(frame, 17, 0, "Use Flash-Attention Fallback", |
| 341 | + tooltip="Enables Flash-Attention fallback on Windows if native support is not available in PyTorch for a performance improvement during training/sampling.") |
| 342 | + components.switch(frame, 17, 1, self.ui_state, "use_flash_attn_fallback", command=self._flash_attn_fallback_toggle) |
| 343 | + |
338 | 344 | frame.pack(fill="both", expand=1) |
339 | 345 | return frame |
340 | 346 |
|
@@ -913,3 +919,9 @@ def _set_training_button_running(self): |
913 | 919 |
|
914 | 920 | def _set_training_button_stopping(self): |
915 | 921 | self._set_training_button_style("stopping") |
| 922 | + |
| 923 | + def _flash_attn_fallback_toggle(self): |
| 924 | + if self.train_config.use_flash_attn_fallback: |
| 925 | + enable_flash_attn_win() |
| 926 | + else: |
| 927 | + disable_flash_attn_win() |
0 commit comments