|
12 | 12 | // See the License for the specific language governing permissions and
|
13 | 13 | // limitations under the License.
|
14 | 14 | #include <chrono>
|
15 |
| -#include <limits> |
16 | 15 | #include <random>
|
17 | 16 |
|
18 | 17 | #include "habanalabs/perf_lib_layer_params.h"
|
|
23 | 22 | #include "kernels/hpu_operator.h"
|
24 | 23 | #include "utils/utils.h"
|
25 | 24 |
|
26 |
| -const float NEG_INF = std::numeric_limits<float>::lowest(); |
27 |
| - |
28 | 25 | namespace custom_kernel {
|
29 | 26 |
|
30 | 27 | class TopP : public HpuFusedOperator {
|
@@ -75,19 +72,18 @@ class TopP : public HpuFusedOperator {
|
75 | 72 | std::vector<synTensor> less_equal_outs = {mask};
|
76 | 73 | AddNodeLessEqual<T>(less_equal_ins, less_equal_outs, guid_ + "less_equal");
|
77 | 74 |
|
78 |
| - // Scalar Node neg_inf |
| 75 | + // Scalar Node Zero |
79 | 76 | std::vector<int64_t> scalar_dims = {1};
|
80 |
| - auto neg_inf = |
81 |
| - createTensorNoPresist("neg_inf", syn_type_float, scalar_dims); |
| 77 | + auto zero = createTensorNoPresist("zero", syn_type_float, scalar_dims); |
82 | 78 | ns_ConstantKernel::Params const_params;
|
83 |
| - const_params.constant.f = NEG_INF; |
84 |
| - std::vector<synTensor> full_out = {neg_inf}; |
85 |
| - AddNodeFull<T>(full_out, const_params, guid_ + "full_neg_inf"); |
| 79 | + const_params.constant.f = 0; |
| 80 | + std::vector<synTensor> full_out = {zero}; |
| 81 | + AddNodeFull<T>(full_out, const_params, guid_ + "full_zero"); |
86 | 82 |
|
87 | 83 | // Where to populate unwanted probs with -inf
|
88 | 84 | auto filtered_probs =
|
89 | 85 | createTensorNoPresist("filtered_probs", inputs[0].type, inputs[0].dims);
|
90 |
| - std::vector<synTensor> where_ins = {mask, sorted_probs, neg_inf}; |
| 86 | + std::vector<synTensor> where_ins = {mask, sorted_probs, zero}; |
91 | 87 | std::vector<synTensor> where_outs = {filtered_probs};
|
92 | 88 | AddNodeWhere<T>(where_ins, where_outs, guid_ + "where");
|
93 | 89 |
|
|
0 commit comments