I'm trying to build an n-gram counter with JAX. All sentences are tokenized beforehand, so I will mostly be working with integers rather than strings.
I'm stuck on how to store the counts in a data structure that can grow as the process streams in more samples. I've tried using a Python dictionary keyed by token to store the counts, but that doesn't work: either I get an error that the key is unhashable, or, if I start with a non-empty dictionary, I get an error about the variable needing to keep the same shape.
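For comparison, a minimal sketch of the fixed-shape pattern JAX expects, assuming the vocabulary size is known ahead of time and every token id lies in `[0, VOCAB_SIZE)` (`VOCAB_SIZE` and `add_unigrams` are hypothetical names): instead of a dict that grows, preallocate one counter per token id and scatter-add into it. It only counts unigrams; its purpose is to show a "growing" table expressed as a fixed-size array.

```python
import jax
import jax.numpy as jnp

# Hypothetical vocabulary size: JAX arrays have static shapes, so the
# table is allocated up front with one slot per possible token id.
VOCAB_SIZE = 8


@jax.jit
def add_unigrams(counts, tokens):
    # .at[...].add() is a functional scatter-add: it returns a new array,
    # and repeated indices in `tokens` accumulate rather than overwrite.
    return counts.at[tokens].add(1)


counts = jnp.zeros(VOCAB_SIZE, dtype=jnp.int32)
counts = add_unigrams(counts, jnp.array([1, 2, 3, 2]))
counts = add_unigrams(counts, jnp.array([2, 5]))
print(counts.tolist())  # [0, 1, 3, 1, 0, 1, 0, 0]
```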
Here is a simple version I came up with that doesn't work:
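```python
import jax

tree = {}
ngram = [1, 2, 3]


def ngram_increment(tree, ngram):
    _tree = tree

    def true_fn(tree, token):
        return tree

    def false_fn(tree, token):
        # lax.cond traces both branches with abstract operands, so `token`
        # is a tracer here and cannot be used as a dict key (unhashable).
        tree[token] = {"count": 0}
        return tree

    for token in ngram:
        # Both branches must also return pytrees with identical structure,
        # which false_fn breaks by adding a new key.
        _tree = jax.lax.cond(
            token in _tree,
            true_fn,
            false_fn,
            _tree, token
        )
        _tree = _tree[token]
        _tree["count"] += 1
    return tree


tree = ngram_increment(tree, ngram)
```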
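One way to sidestep both errors is to keep Python dicts out of traced code entirely and give every possible n-gram a fixed slot in a dense array. A minimal sketch, assuming a small hypothetical `VOCAB_SIZE` and `N` (the table has `VOCAB_SIZE ** N` entries, so this does not scale to realistic vocabularies); `encode` and `ngram_ids` are illustrative helpers, not a JAX API. Each n-gram is encoded as a base-`VOCAB_SIZE` integer and the ids are tallied with `jnp.bincount`.

```python
import jax.numpy as jnp

# Hypothetical settings: the dense table has VOCAB_SIZE ** N slots, so this
# only fits in memory for small vocabularies and small N.
VOCAB_SIZE = 8
N = 3


def encode(ngram):
    # Treat an n-gram as a base-VOCAB_SIZE number: (t0, t1, t2) -> t0*V^2 + t1*V + t2.
    idx = 0
    for t in ngram:
        idx = idx * VOCAB_SIZE + int(t)
    return idx


def ngram_ids(tokens):
    # Build all sliding windows tokens[i:i+N] at once with a gather.
    num_windows = tokens.shape[0] - N + 1
    idx = jnp.arange(num_windows)[:, None] + jnp.arange(N)[None, :]
    windows = tokens[idx]  # shape (num_windows, N)
    # Same base-VOCAB_SIZE encoding as encode(), vectorized over windows.
    powers = jnp.array([VOCAB_SIZE ** k for k in range(N - 1, -1, -1)], dtype=jnp.int32)
    return jnp.sum(windows * powers, axis=1)


tokens = jnp.array([1, 2, 3, 1, 2, 3])
counts = jnp.bincount(ngram_ids(tokens), length=VOCAB_SIZE ** N)
print(int(counts[encode([1, 2, 3])]))  # 2
```

Unlike the trie above, which also increments a count for every prefix of the n-gram, this sketch only counts complete n-grams; prefix counts would need one table per order. For large vocabularies a dense table is infeasible, and something like hashing the ids into a fixed number of buckets (accepting collisions) would be needed instead.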
I'm hoping to use JAX because it seems like n-gram counting could run on accelerators and massively speed up the process.
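For the accelerator angle, a rough sketch of a streaming, jitted update under the same assumptions as a dense table (small hypothetical `VOCAB_SIZE`, bigrams, fixed-length rows; `row_ngram_ids`, `update`, and `stream` are illustrative names). Each batch of token rows is turned into n-gram ids with `jax.vmap`, and the ids are scatter-added into a count array that is threaded through the loop. Because every batch has the same static shape, `jax.jit` compiles `update` once and can run it on a GPU or TPU.

```python
import jax
import jax.numpy as jnp

# Hypothetical settings for a bigram table with VOCAB_SIZE ** N slots.
VOCAB_SIZE = 8
N = 2
POWERS = jnp.array([VOCAB_SIZE ** k for k in range(N - 1, -1, -1)], dtype=jnp.int32)


def row_ngram_ids(tokens):
    # tokens: int32[L] -> int32[L - N + 1], one base-VOCAB_SIZE id per window.
    num_windows = tokens.shape[0] - N + 1
    idx = jnp.arange(num_windows)[:, None] + jnp.arange(N)[None, :]
    return jnp.sum(tokens[idx] * POWERS, axis=1)


@jax.jit
def update(counts, batch):
    # batch: int32[B, L]; a new (B, L) shape would trigger a recompile.
    ids = jax.vmap(row_ngram_ids)(batch)  # int32[B, L - N + 1]
    return counts.at[ids.ravel()].add(1)


# Hypothetical stream of pre-tokenized batches with a fixed (B, L) shape;
# real data would need padding or bucketing, which this sketch does not handle.
stream = [
    jnp.array([[1, 2, 3], [2, 3, 4]]),
    jnp.array([[1, 2, 1], [3, 4, 0]]),
]

counts = jnp.zeros(VOCAB_SIZE ** N, dtype=jnp.int32)
for batch in stream:
    counts = update(counts, batch)
```

The loop is the streaming part: the table never grows, each batch just produces an updated copy of the same fixed-size array.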