Skip to content

Commit 2d2ab96

Browse files
authored
[INTEL_HPU] fix regression caused by multinomial + neg_inf (#1893)
1 parent 281f5ca commit 2d2ab96

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

backends/intel_hpu/kernels/top_p_hpu.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#include <chrono>
15-
#include <limits>
1615
#include <random>
1716

1817
#include "habanalabs/perf_lib_layer_params.h"
@@ -23,8 +22,6 @@
2322
#include "kernels/hpu_operator.h"
2423
#include "utils/utils.h"
2524

26-
const float NEG_INF = std::numeric_limits<float>::lowest();
27-
2825
namespace custom_kernel {
2926

3027
class TopP : public HpuFusedOperator {
@@ -75,19 +72,18 @@ class TopP : public HpuFusedOperator {
7572
std::vector<synTensor> less_equal_outs = {mask};
7673
AddNodeLessEqual<T>(less_equal_ins, less_equal_outs, guid_ + "less_equal");
7774

78-
// Scalar Node neg_inf
75+
// Scalar Node Zero
7976
std::vector<int64_t> scalar_dims = {1};
80-
auto neg_inf =
81-
createTensorNoPresist("neg_inf", syn_type_float, scalar_dims);
77+
auto zero = createTensorNoPresist("zero", syn_type_float, scalar_dims);
8278
ns_ConstantKernel::Params const_params;
83-
const_params.constant.f = NEG_INF;
84-
std::vector<synTensor> full_out = {neg_inf};
85-
AddNodeFull<T>(full_out, const_params, guid_ + "full_neg_inf");
79+
const_params.constant.f = 0;
80+
std::vector<synTensor> full_out = {zero};
81+
AddNodeFull<T>(full_out, const_params, guid_ + "full_zero");
8682

8783
// Where to populate unwanted probs with -inf
8884
auto filtered_probs =
8985
createTensorNoPresist("filtered_probs", inputs[0].type, inputs[0].dims);
90-
std::vector<synTensor> where_ins = {mask, sorted_probs, neg_inf};
86+
std::vector<synTensor> where_ins = {mask, sorted_probs, zero};
9187
std::vector<synTensor> where_outs = {filtered_probs};
9288
AddNodeWhere<T>(where_ins, where_outs, guid_ + "where");
9389

0 commit comments

Comments
 (0)