Description
Describe the bug
Calling model.predict(input) raises a ValueError that I haven't encountered before with SimpleTransformers.
To Reproduce
transformers==4.55.4, simpletransformers==0.70.5
Sample code:
model_args = MultiLabelClassificationArgs(
    num_train_epochs=N_EPOCHS,
    overwrite_output_dir=True,
    train_batch_size=32,
    output_dir=args.output_path,
    use_multiprocessing=False,
    use_multiprocessing_for_evaluation=False,
    save_steps=-1,  # set to -1 to disable
    save_eval_checkpoints=False,
    save_model_every_epoch=False,
)
logger.info('model args: ')
logger.info(model_args)

model = MultiLabelClassificationModel(
    'roberta',
    'distilroberta-base',
    use_cuda=True,
    args=model_args,
    num_labels=1,
)

# Train the model
model.train_model(train)

test_data = list(test['Text'])
logger.info(f'Test len: {test.shape}')  # Test len: (66, 2)

# Evaluate on test set
predictions, raw_outputs = model.predict(test_data)
Screenshots
Predicting: 0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 135, in <module>
    predictions, raw_outputs = model.predict(test_data)
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lib/python3.12/site-packages/simpletransformers/classification/multi_label_classification_model.py", line 406, in predict
    return super().predict(to_predict, multi_label=multi_label)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lib/python3.12/site-packages/simpletransformers/classification/classification_model.py", line 2295, in predict
    out_label_ids[start_index:end_index] = (
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
ValueError: could not broadcast input array from shape (66,) into shape (66,1)
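For reference, the same ValueError can be reproduced with NumPy alone, which suggests the assignment on line 2295 receives a 1-D array where a column of shape (66, 1) is expected. The sketch below is only an illustration of the shape mismatch: the array names and contents are hypothetical stand-ins, not the library's actual internals, and I have not confirmed what produces the 1-D array inside SimpleTransformers.

```python
import numpy as np

# Hypothetical stand-ins for the shapes in the traceback:
# 66 test rows and num_labels=1 give a (66, 1) destination.
out_label_ids = np.zeros((66, 1))
batch_labels = np.zeros(66)  # 1-D source, shape (66,)

try:
    # A (66,) array cannot be broadcast into a (66, 1) slice.
    out_label_ids[0:66] = batch_labels
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)

# Reshaping the source into a column makes the assignment valid.
out_label_ids[0:66] = batch_labels.reshape(-1, 1)
```

If the library is flattening the labels somewhere when num_labels=1, a reshape of this kind before the assignment would avoid the error, but that is an assumption about the cause rather than a verified fix.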
Desktop (please complete the following information):
- Ubuntu
Additional context
My dataset has two columns: text and labels
@ThilinaRajapakse Thank you so much for fixing the previous issues, I really appreciate it!
I ran into this new issue when trying to make predictions with the trained model. I’m not sure if it’s something specific to my environment, but I didn’t encounter this before, and I haven’t changed any code or packages aside from upgrading these two (transformers==4.55.4, simpletransformers==0.70.5).
Happy to share more details if needed!