30 | 30 |
31 | 31 | FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
32 | 32 |
33 | | -def get_attn_precision(attn_precision):
| 33 | +def get_attn_precision(attn_precision, current_dtype):
34 | 34 |     if args.dont_upcast_attention:
35 | 35 |         return None
36 | | -    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
37 | | -        return FORCE_UPCAST_ATTENTION_DTYPE
| 36 | +
| 37 | +    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
| 38 | +        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
38 | 39 |     return attn_precision
39 | 40 |
40 | 41 | def exists(val):
@@ -81,7 +82,7 @@ def Normalize(in_channels, dtype=None, device=None):
81 | 82 |     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
82 | 83 |
83 | 84 | def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
84 | | -    attn_precision = get_attn_precision(attn_precision)
| 85 | +    attn_precision = get_attn_precision(attn_precision, q.dtype)
85 | 86 |
86 | 87 |     if skip_reshape:
87 | 88 |         b, _, _, dim_head = q.shape
@@ -150,7 +151,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
150 | 151 |
151 | 152 |
152 | 153 | def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
153 | | -    attn_precision = get_attn_precision(attn_precision)
| 154 | +    attn_precision = get_attn_precision(attn_precision, query.dtype)
154 | 155 |
155 | 156 |     if skip_reshape:
156 | 157 |         b, _, _, dim_head = query.shape
@@ -220,7 +221,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
220 | 221 |     return hidden_states
221 | 222 |
222 | 223 | def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
223 | | -    attn_precision = get_attn_precision(attn_precision)
| 224 | +    attn_precision = get_attn_precision(attn_precision, q.dtype)
224 | 225 |
225 | 226 |     if skip_reshape:
226 | 227 |         b, _, _, dim_head = q.shape
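The net effect of these hunks is that `get_attn_precision` no longer returns one globally forced dtype; it now looks the current tensor dtype up in a mapping, and each attention variant passes `q.dtype` (or `query.dtype`) through. Below is a minimal, self-contained sketch of that contract, not the ComfyUI implementation itself: it assumes `model_management.force_upcast_attention_dtype()` now returns either `None` or a dtype-to-dtype dict (the `{torch.float16: torch.float32}` value is an illustrative assumption), and replaces `args.dont_upcast_attention` with a plain module-level flag.

```python
import torch

# Assumption: force_upcast_attention_dtype() now returns either None or a
# dtype -> dtype mapping; this example value is illustrative only.
FORCE_UPCAST_ATTENTION_DTYPE = {torch.float16: torch.float32}

# Stand-in for args.dont_upcast_attention from the diff.
DONT_UPCAST_ATTENTION = False


def get_attn_precision(attn_precision, current_dtype):
    if DONT_UPCAST_ATTENTION:
        return None
    # Upcast only when the tensor's dtype appears in the mapping; otherwise
    # keep whatever precision the caller requested.
    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
    return attn_precision


q = torch.randn(2, 8, 16, 64, dtype=torch.float16)
print(get_attn_precision(None, q.dtype))        # torch.float32: fp16 is in the mapping
print(get_attn_precision(None, torch.float32))  # None: fp32 is not remapped
```

Keying the forced upcast on the query's dtype means only the dtypes listed in the mapping get promoted, instead of every attention call being upcast whenever any forced dtype is configured.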