We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fa5e448 commit c7c0183Copy full SHA for c7c0183
flame/data.py
@@ -40,12 +40,12 @@ def __init__(
40
self.world_size = world_size
41
self.buffer_size = buffer_size
42
43
- if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
44
- self.dtype = torch.int16
45
- elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
46
- self.dtype = torch.int32
+ if tokenizer.vocab_size < torch.iinfo(torch.uint16).max:
+ self.dtype = torch.uint16
+ elif tokenizer.vocab_size < torch.iinfo(torch.uint32).max:
+ self.dtype = torch.uint32
47
else:
48
- self.dtype = torch.int64
+ self.dtype = torch.uint64
49
self.states = None
50
self.buffer = torch.tensor([], dtype=self.dtype)
51
self.tokens = []
0 commit comments