Commit 96d891c

Speedup on some models by not upcasting bfloat16 to float32 on mac.
1 parent 4553891 commit 96d891c

File tree

2 files changed: +8 -7 lines

comfy/ldm/modules/attention.py
Lines changed: 7 additions & 6 deletions

@@ -30,11 +30,12 @@
 
 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
 
-def get_attn_precision(attn_precision):
+def get_attn_precision(attn_precision, current_dtype):
     if args.dont_upcast_attention:
         return None
-    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
-        return FORCE_UPCAST_ATTENTION_DTYPE
+
+    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
+        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
     return attn_precision
 
 def exists(val):

@@ -81,7 +82,7 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)
 
     if skip_reshape:
         b, _, _, dim_head = q.shape

@@ -150,7 +151,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 
 
 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, query.dtype)
 
     if skip_reshape:
         b, _, _, dim_head = query.shape

@@ -220,7 +221,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     return hidden_states
 
 def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)
 
     if skip_reshape:
         b, _, _, dim_head = q.shape
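The net effect of the attention.py change: get_attn_precision now also receives the tensor's current dtype and only upcasts dtypes that appear in the mapping returned by force_upcast_attention_dtype(). Below is a minimal standalone sketch of that lookup, with the module globals replaced by hard-coded stand-ins (DONT_UPCAST_ATTENTION and the mapping value are illustrative assumptions, not the real args/config plumbing):

    import torch

    # Hypothetical stand-ins for the module-level globals in attention.py.
    DONT_UPCAST_ATTENTION = False                                   # stands in for args.dont_upcast_attention
    FORCE_UPCAST_ATTENTION_DTYPE = {torch.float16: torch.float32}   # mapping now returned on mac

    def get_attn_precision(attn_precision, current_dtype):
        if DONT_UPCAST_ATTENTION:
            return None
        # Only dtypes listed in the mapping are forced to a higher precision;
        # bfloat16 has no entry, so it is no longer promoted to float32.
        if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
            return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
        return attn_precision

    print(get_attn_precision(None, torch.float16))    # torch.float32 (still upcast)
    print(get_attn_precision(None, torch.bfloat16))   # None (kept as-is, which is the speedup)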

comfy/model_management.py
Lines changed: 1 addition & 1 deletion

@@ -954,7 +954,7 @@ def force_upcast_attention_dtype():
         upcast = True
 
     if upcast:
-        return torch.float32
+        return {torch.float16: torch.float32}
     else:
         return None
 
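The model_management.py change swaps the single forced dtype for a per-source-dtype mapping. A small sketch of what callers see, assuming a simplified upcast flag in place of the actual mac-specific check performed by the real function:

    import torch

    def force_upcast_attention_dtype_sketch(upcast=True):
        # Sketch only: `upcast` stands in for the platform check that decides
        # whether attention precision should be forced up at all.
        if upcast:
            # New behaviour: only float16 is mapped to float32.
            return {torch.float16: torch.float32}
        return None

    mapping = force_upcast_attention_dtype_sketch()
    for dtype in (torch.float16, torch.bfloat16):
        target = mapping.get(dtype) if mapping is not None else None
        print(dtype, "->", target)
    # torch.float16 -> torch.float32   (still upcast)
    # torch.bfloat16 -> None           (left in bfloat16)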
