File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -290,7 +290,7 @@ class Transformer : public EncoderOrDecoderBase {
290290 // memoization propagation (short-term)
291291 if (cache // if caching
292292 && cache_.count (prefix + " _keys" ) > 0 // and the keys expression has been seen
293- && cache_[prefix + " _keys" ]->shape (). elements () == keys->shape (). elements ()) { // and the underlying element size did not change
293+ && cache_[prefix + " _keys" ]->shape () == keys->shape ()) { // and the underlying shape did not change
294294 kh = cache_[prefix + " _keys" ]; // then return cached tensor
295295 }
296296 else {
@@ -306,7 +306,7 @@ class Transformer : public EncoderOrDecoderBase {
306306 Expr vh;
307307 if (cache
308308 && cache_.count (prefix + " _values" ) > 0
309- && cache_[prefix + " _values" ]->shape (). elements () == values->shape (). elements ()) {
309+ && cache_[prefix + " _values" ]->shape () == values->shape ()) {
310310 vh = cache_[prefix + " _values" ];
311311 } else {
312312 int dimValues = values->shape ()[-1 ]; // different than dimModel when using lemma and factors combined with concatenation
You can’t perform that action at this time.
0 commit comments