Commit efa87b1

Migrate functionality into UI toggle
1 parent 32a72e7 commit efa87b1

4 files changed (+17, -5 lines)

modules/ui/TrainUI.py

Lines changed: 12 additions & 0 deletions
@@ -29,6 +29,7 @@
 from modules.ui.TrainingTab import TrainingTab
 from modules.ui.VideoToolUI import VideoToolUI
 from modules.util import create
+from modules.util.attn.flash_attn_win import disable_flash_attn_win, enable_flash_attn_win
 from modules.util.callbacks.TrainCallbacks import TrainCallbacks
 from modules.util.commands.TrainCommands import TrainCommands
 from modules.util.config.TrainConfig import TrainConfig
@@ -133,6 +134,7 @@ def __init__(self):
         self.always_on_tensorboard_subprocess = None
         self.current_workspace_dir = self.train_config.workspace_dir
         self._check_start_always_on_tensorboard()
+        self._flash_attn_fallback_toggle()

         self.workspace_dir_trace_id = self.ui_state.add_var_trace("workspace_dir", self._on_workspace_dir_change_trace)

@@ -335,6 +337,10 @@ def create_general_tab(self, master):
                          tooltip="The device used to temporarily offload models while they are not used. Default:\"cpu\"")
         components.entry(frame, 16, 1, self.ui_state, "temp_device")

+        components.label(frame, 17, 0, "Use Flash-Attention Fallback",
+                         tooltip="Enables Flash-Attention fallback on Windows if native support is not available in PyTorch for a performance improvement during training/sampling.")
+        components.switch(frame, 17, 1, self.ui_state, "use_flash_attn_fallback", command=self._flash_attn_fallback_toggle)
+
         frame.pack(fill="both", expand=1)
         return frame

@@ -913,3 +919,9 @@ def _set_training_button_running(self):

     def _set_training_button_stopping(self):
         self._set_training_button_style("stopping")
+
+    def _flash_attn_fallback_toggle(self):
+        if self.train_config.use_flash_attn_fallback:
+            enable_flash_attn_win()
+        else:
+            disable_flash_attn_win()
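
For reference, a minimal sketch of the behaviour wired up above, usable outside the UI. The import path comes from this diff; apply_flash_attn_fallback and its boolean argument are hypothetical stand-ins for reading TrainConfig.use_flash_attn_fallback.

# Minimal sketch, assuming only the helpers imported in the diff above.
from modules.util.attn.flash_attn_win import disable_flash_attn_win, enable_flash_attn_win


def apply_flash_attn_fallback(use_fallback: bool) -> None:
    # Same branch as TrainUI._flash_attn_fallback_toggle: install the Windows
    # fallback when the flag is set, restore the stock behaviour otherwise.
    if use_fallback:
        enable_flash_attn_win()
    else:
        disable_flash_attn_win()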

modules/util/attn/flash_attn_win.py

Lines changed: 3 additions & 2 deletions
@@ -146,7 +146,8 @@ def _flash_dynamic_scaled_dot_product_attention(query: torch.Tensor,
                                                 dropout_p: float = 0.0,
                                                 is_causal: bool = False,
                                                 scale: float | None = None,
-                                                enable_gqa: bool = False):
+                                                enable_gqa: bool = False,
+                                                _fallback_sdpa = _scaled_dot_product_attention):
     if can_use_flash_attn(query, key, value, attn_mask, is_causal, enable_gqa):
         # transpose(1,2) is equivalent to permute(0,2,1,3) for (B,H,L,D) -> (B,L,H,D)
         q = query.transpose(1, 2)
@@ -161,7 +162,7 @@ def _flash_dynamic_scaled_dot_product_attention(query: torch.Tensor,
         return out.transpose(1, 2)

     # Fallback
-    return _scaled_dot_product_attention(
+    return _fallback_sdpa(
         query=query, key=key, value=value,
         attn_mask=attn_mask, dropout_p=dropout_p,
         is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
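
The new _fallback_sdpa default argument binds a reference to the original SDPA at function-definition time. The enable/disable helpers are not shown in this diff, so the patching code below is a hedged sketch of why that binding matters, not the project's actual implementation.

# Hedged sketch: illustrates the default-argument pattern only; the real
# enable_flash_attn_win/disable_flash_attn_win bodies are not part of this diff.
import torch.nn.functional as F

_original_sdpa = F.scaled_dot_product_attention  # captured before any patching


def _wrapped_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, scale=None,
                  _fallback_sdpa=_original_sdpa):
    # When flash-attn cannot be used, delegate to the SDPA captured at
    # definition time. Calling F.scaled_dot_product_attention here instead
    # would recurse into this wrapper once it has been patched in.
    return _fallback_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal, scale=scale)


def enable_sketch():
    F.scaled_dot_product_attention = _wrapped_sdpa   # install the wrapper


def disable_sketch():
    F.scaled_dot_product_attention = _original_sdpa  # restore the original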

modules/util/config/TrainConfig.py

Lines changed: 2 additions & 0 deletions
@@ -377,6 +377,7 @@ class TrainConfig(BaseConfig):
     loss_scaler: LossScaler
     learning_rate_scaler: LearningRateScaler
     clip_grad_norm: float
+    use_flash_attn_fallback: bool

     #layer filter
     layer_filter: str # comma-separated
@@ -931,6 +932,7 @@ def default_values() -> 'TrainConfig':
         data.append(("loss_scaler", LossScaler.NONE, LossScaler, False))
         data.append(("learning_rate_scaler", LearningRateScaler.NONE, LearningRateScaler, False))
         data.append(("clip_grad_norm", 1.0, float, True))
+        data.append(("use_flash_attn_fallback", True, bool, False))

         # noise
         data.append(("offset_noise_weight", 0.0, float, False))

scripts/util/import_util.py

Lines changed: 0 additions & 3 deletions
@@ -31,6 +31,3 @@ def script_imports(allow_zluda: bool = True):
         from modules.zluda import ZLUDA

         ZLUDA.initialize()
-
-    from modules.util.attn.flash_attn_win import enable_flash_attn_win
-    enable_flash_attn_win()
