File tree Expand file tree Collapse file tree 2 files changed +8
-4
lines changed Expand file tree Collapse file tree 2 files changed +8
-4
lines changed Original file line number Diff line number Diff line change 12
12
from torch .nn import functional as F
13
13
14
14
def prepare_inputs_for_model (inps ):
15
+ # this is because input from lm-eval is 2d
16
+ if input .dim () != 2 :
17
+ raise ValueError (f"Expected input to be of dim 2, but got { input .dim ()} " )
18
+
15
19
inps = inps .squeeze (0 )
16
20
# setup inputs in correct format
17
21
max_new_tokens = 1
Original file line number Diff line number Diff line change @@ -1225,9 +1225,9 @@ def __init__(
1225
1225
calibration_limit ,
1226
1226
calibration_seq_length ,
1227
1227
pad_calibration_inputs ,
1228
- inner_k_tiles = 8 ,
1229
- padding_allowed = True ,
1230
- precision = torch . float32 ,
1231
- _is_gpt_fast = True ,
1228
+ inner_k_tiles = inner_k_tiles ,
1229
+ padding_allowed = padding_allowed ,
1230
+ precision = precision ,
1231
+ _is_gpt_fast = _is_gpt_fast ,
1232
1232
_use_cuda = _use_cuda ,
1233
1233
)
You can’t perform that action at this time.
0 commit comments