@@ -61,21 +61,18 @@ def test_advance_with_all_repeats_gets_blocked(self):
6161 # (but it's still the best score, thus we have
6262 # [BLOCKED_SCORE, -inf, -inf, -inf, -inf]
6363 expected_scores = torch .tensor (
64- [0 ] + [- float ('inf' )] * (beam_sz - 1 ))\
65- .repeat (batch_sz , 1 )
66- expected_scores [:, 0 ] = self .BLOCKED_SCORE
64+ [self .BLOCKED_SCORE ] + [- float ('inf' )] * (beam_sz - 1 )
65+ ).repeat (batch_sz , 1 )
6766 self .assertTrue (beam .topk_log_probs .equal (expected_scores ))
6867 else :
6968 # repetitions keeps maximizing score
7069 # index 0 has been blocked, so repeating=>+0.0 score
7170 # other indexes are -inf so repeating=>BLOCKED_SCORE
7271 # which is higher
7372 expected_scores = torch .tensor (
74- [0 ] + [- float ('inf' )] * (beam_sz - 1 ))\
75- .repeat (batch_sz , 1 )
76- expected_scores [:, :] = self .BLOCKED_SCORE
77- expected_scores = torch .tensor (
78- self .BLOCKED_SCORE ).repeat (batch_sz , beam_sz )
73+ [self .BLOCKED_SCORE ] + [- float ('inf' )] * (beam_sz - 1 )
74+ ).repeat (batch_sz , 1 )
75+ self .assertTrue (beam .topk_log_probs .equal (expected_scores ))
7976
8077 def test_advance_with_some_repeats_gets_blocked (self ):
8178 # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
@@ -137,7 +134,8 @@ def test_advance_with_some_repeats_gets_blocked(self):
137134
138135 expected = torch .full ([batch_sz , beam_sz ], float ("-inf" ))
139136 expected [:, 0 ] = no_repeat_score
140- expected [:, 1 :] = self .BLOCKED_SCORE
137+ expected [:, 1 :3 ] = self .BLOCKED_SCORE
138+ expected [:, 3 :] = float ("-inf" )
141139 self .assertTrue (
142140 beam .topk_log_probs .equal (expected ))
143141
0 commit comments