Need Help with this Trax Error Code #6295
Unanswered
memora0101
asked this question in
Q&A
Replies: 0 comments
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
# Run the model on each training sentence and print the sentence together
# with the model's output.
# NOTE(review): assumes `train_x`, `sent_to_tensor`, `Vocab`, `model` and
# `np` (numpy) are defined earlier in the notebook -- confirm.
for sent in train_x:
    # The Embedding layer looks rows up with jnp.take(weights, x, axis=0), so
    # the token ids must be an integer array. Without an explicit dtype, an
    # empty list becomes a *float* array. That is what produced the
    # "start_indices must have an integer type" LayerError for an input of
    # shape (1, 0), dtype float32.
    inputs = np.array(sent_to_tensor(sent, vocab_dict=Vocab), dtype=np.int32)
    if inputs.size == 0:
        # sent_to_tensor produced no token ids for this sentence (presumably
        # all tokens were removed during preprocessing -- verify). There is
        # nothing to embed, so don't feed a zero-length sequence to the model.
        print(f'Skipping example with no tokens: {sent}')
        continue
    # Add a leading batch axis: (seq_len,) -> (1, seq_len).
    inputs = inputs[None, :]
    predictions = model(inputs)
    print(f'example input_str: {sent}')
    print(f'Model returned sentiment probabilities: {predictions}')
Error:
LayerError Traceback (most recent call last)
in
2 inputs = np.array(sent_to_tensor(sent, vocab_dict=Vocab))
3 inputs = inputs[None, :]
----> 4 predictions=model(inputs)
5 print(f'example input_str: {sent}')
6 print(f'Model returned sentiment probabilities: {predictions}')
~/opt/anaconda3/lib/python3.7/site-packages/trax/layers/base.py in call(self, x, weights, state, rng)
190 self.state = state # Needed if the model wasn't fully initialized.
191 state = self.state
--> 192 outputs, new_state = self.pure_fn(x, weights, state, rng)
193 self.state = new_state
194 self.weights = weights
~/opt/anaconda3/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
547 name, trace = self._name, _short_traceback(skip=3)
548 raise LayerError(name, 'pure_fn',
--> 549 self._caller, signature(x), trace) from None
550
551 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/, line 27
layer input shapes: ShapeDtype{shape:(1, 0), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Embedding_437_256 (in pure_fn):
layer created in file [...]/, line 8
layer input shapes: ShapeDtype{shape:(1, 0), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 181, in forward
embedded = jnp.take(self.weights, x, axis=0)
File [...]/_src/numpy/lax_numpy.py, line 4078, in take
slice_sizes=tuple(slice_sizes))
File [...]/_src/lax/lax.py, line 874, in gather
slice_sizes=canonicalize_shape(slice_sizes))
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/site-packages/jax/core.py, line 628, in process_primitive
return primitive.impl(*tracers, **params)
File [...]/jax/interpreters/xla.py, line 238, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File [...]/jax/_src/util.py, line 198, in wrapper
return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)
File [...]/jax/_src/util.py, line 191, in cached
return f(*args, **kwargs)
File [...]/jax/interpreters/xla.py, line 263, in xla_primitive_callable
aval_out = prim.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
File [...]/_src/lax/lax.py, line 4114, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")
ValueError: start_indices must have an integer type
Beta Was this translation helpful? Give feedback.
All reactions