 
 from __future__ import annotations
 
+import numpy as np
 import math
 from typing import Optional, Sequence, Tuple, TypeVar, Union
 
@@ -2048,6 +2049,9 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
     attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
 
+def float_lowest(dtype):
+    """Returns the lowest representable value for the given numpy dtype."""
+    return np.finfo(np.dtype(dtype)).min
 
 def _aten_scaled_dot_product_attention_bool_mask_onnx(
     query: TFloat,
@@ -2078,7 +2082,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
     # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
     zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
-    neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
+    neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype), dtype=query.dtype))
     attn_mask = op.Where(attn_mask, zero, neg_inf)
     attn_weight = op.Softmax(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
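
Aside (not part of the diff): a minimal NumPy sketch of one practical difference between the two fill values. When every position in a row is masked out, an additive mask built from -inf makes the whole row -inf and softmax over it produces NaN, while the dtype's lowest finite value keeps the result defined. float_lowest below mirrors the helper added above; softmax is a plain max-shifted stand-in for op.Softmax, not the actual ONNX implementation.

import numpy as np

def float_lowest(dtype):
    """Lowest representable finite value for the given numpy dtype."""
    return np.finfo(np.dtype(dtype)).min

def softmax(x, axis=-1):
    # Max-shifted softmax, standing in for op.Softmax.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

scores = np.zeros((1, 3), dtype=np.float32)    # raw attention scores
bool_mask = np.array([[False, False, False]])  # an entirely masked-out row

# Additive mask built like the op.Where in the diff:
# 0 where attention is allowed, a large negative value where it is not.
neg_inf_mask = np.where(bool_mask, 0.0, -np.inf)
lowest_mask = np.where(bool_mask, 0.0, float_lowest(np.float32))

print(softmax(scores + neg_inf_mask))  # [[nan nan nan]] -- all--inf row collapses to NaN
print(softmax(scores + lowest_mask))   # [[0.333.. 0.333.. 0.333..]] -- stays finite

As the sketch shows, the finite fill value turns a fully masked row into a uniform attention distribution instead of NaN, while rows with at least one unmasked position behave the same under both fill values.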