Skip to content

Commit e706971

Browse files
authored
Merge pull request #569 from cpcdoy/fix/device
Fix GPU usage by removing `device` from `Transformers` class wrapper to use the device/device_map directly exposed by HF Transformers in kwargs
2 parents 716d7c1 + ba5a4cd commit e706971

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

guidance/models/transformers/_transformers.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414

1515
class Transformers(Model):
16-
def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperature=0.0, compute_log_probs=False, device=None, **kwargs):
16+
def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperature=0.0, compute_log_probs=False, **kwargs):
1717
'''Build a new Transformers model object that represents a model in a given state.'''
1818

1919
# fill in default model value
@@ -34,8 +34,6 @@ def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperat
3434
# self.current_time = time.time()
3535
# self.call_history = collections.deque()
3636
self.temperature = temperature
37-
if device is not None: # set the device if requested
38-
self.model_obj = self.model_obj.to(device)
3937
self.device = self.model_obj.device # otherwise note the current device
4038

4139
# build out the set of byte_string tokens

0 commit comments

Comments (0)