Skip to content

Commit a817edd

Browse files
committed
Check size on transformer cache
Related: marian-nmt#881
1 parent e27da62 commit a817edd

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/models/transformer.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -290,7 +290,7 @@ class Transformer : public EncoderOrDecoderBase {
290290
// memoization propagation (short-term)
291291
if (cache // if caching
292292
&& cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
293-
&& cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
293+
&& cache_[prefix + "_keys"]->shape() == keys->shape()) { // and the underlying shape did not change
294294
kh = cache_[prefix + "_keys"]; // then return cached tensor
295295
}
296296
else {
@@ -306,7 +306,7 @@ class Transformer : public EncoderOrDecoderBase {
306306
Expr vh;
307307
if (cache
308308
&& cache_.count(prefix + "_values") > 0
309-
&& cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
309+
&& cache_[prefix + "_values"]->shape() == values->shape()) {
310310
vh = cache_[prefix + "_values"];
311311
} else {
312312
int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation

0 commit comments

Comments (0)