diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py index 55c36f84a..7af3de73a 100755 --- a/llmc/data/dataset/base_dataset.py +++ b/llmc/data/dataset/base_dataset.py @@ -25,7 +25,7 @@ def __init__(self, tokenizer, calib_cfg, batch_process=None): self.load_from_txt = calib_cfg.get('load_from_txt', False) self.calib_dataset_path = calib_cfg.get('path', None) self.apply_chat_template = calib_cfg.get('apply_chat_template', False) - self.n_samples = calib_cfg.get('seq_len', None) + self.n_samples = calib_cfg.get('n_samples', None) self.calib_bs = calib_cfg['bs'] if self.calib_dataset_name in ['t2v', 'i2v']: assert self.calib_bs == 1