Skip to content

Commit f53d6a7

Browse files
Tuan Tran
authored and committed
Merge branch 'main' into tuan/fix_17
2 parents 7e4c89c + 67b7ccc commit f53d6a7

File tree

38 files changed

+268
-266
lines changed

38 files changed

+268
-266
lines changed

lcm/datasets/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def pipeline(self) -> DataPipeline:
5353
self._pipeline = self.builder_func(
5454
self.datasets, self.data_config, gang_rank, world_size
5555
)
56-
assert (
57-
self._pipeline
58-
), f"Cannot build data pipeline from config {self.data_config}"
56+
assert self._pipeline, (
57+
f"Cannot build data pipeline from config {self.data_config}"
58+
)
5959
return self._pipeline
6060

6161
def destroy(self) -> None:

lcm/datasets/batch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,9 @@ def __post_init__(self):
249249

250250
length = len(self.source)
251251

252-
assert (
253-
(self.target is None) or (len(self.target) == length)
254-
), f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}"
252+
assert (self.target is None) or (len(self.target) == length), (
253+
f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}"
254+
)
255255

256256
def __len__(self) -> int:
257257
return len(self.source)
@@ -296,9 +296,9 @@ def prepare_input(
296296
)
297297

298298
elif style == LCMStyle.SUPERVISED:
299-
assert (
300-
self.target is not None
301-
), "Missing target embeddings for a supervised batch"
299+
assert self.target is not None, (
300+
"Missing target embeddings for a supervised batch"
301+
)
302302
return get_embeddings_sequence(
303303
src_seqs=self.source,
304304
tgt_seqs=self.target,

lcm/datasets/dataloader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,9 @@ def _tokenize_batch(self, batch: Dict[str, Any]) -> LCMInput:
195195
else:
196196
embs = None
197197
outputs[key] = embs
198-
assert (
199-
outputs["source"] is not None
200-
), "LCMDataLoader requires `source` sequences to be present in batches"
198+
assert outputs["source"] is not None, (
199+
"LCMDataLoader requires `source` sequences to be present in batches"
200+
)
201201
return LCMInput(**outputs)
202202

203203
def iterate_batches(self) -> Iterator[LCMInput]:

lcm/datasets/dataloading.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ def build_dataload_pipeline(
202202
self, rank: int = 0, world_size: int = 1
203203
) -> DataPipelineBuilder:
204204
if world_size > 1:
205-
assert (
206-
self.loading_config.seed is not None
207-
), "for distributed training with `world_size` > 1, `seed` should be set !"
205+
assert self.loading_config.seed is not None, (
206+
"for distributed training with `world_size` > 1, `seed` should be set !"
207+
)
208208
if self.is_validation:
209209
self.set_validation_params(world_size)
210210

@@ -321,12 +321,12 @@ def create_on_the_fly_columns(
321321
self, pipeline: DataPipelineBuilder
322322
) -> DataPipelineBuilder:
323323
if self.dataset_config.source_sequences is not None:
324-
assert (
325-
self.dataset_config.source_column is not None
326-
), f"Expected a source_column - found {self.dataset_config.source_column}"
327-
assert (
328-
self.dataset_config.source_text_column is not None
329-
), f"Expected a source_text_column - found {self.dataset_config.source_text_column}"
324+
assert self.dataset_config.source_column is not None, (
325+
f"Expected a source_column - found {self.dataset_config.source_column}"
326+
)
327+
assert self.dataset_config.source_text_column is not None, (
328+
f"Expected a source_text_column - found {self.dataset_config.source_text_column}"
329+
)
330330

331331
pipeline = pipeline.map(
332332
partial(
@@ -338,12 +338,12 @@ def create_on_the_fly_columns(
338338
num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
339339
)
340340
if self.dataset_config.target_sequences is not None:
341-
assert (
342-
self.dataset_config.target_column is not None
343-
), f"Expected a target_column, found {self.dataset_config.target_column}"
344-
assert (
345-
self.dataset_config.target_text_column is not None
346-
), f"Expected a target_text_columns, found {self.dataset_config.target_text_column}"
341+
assert self.dataset_config.target_column is not None, (
342+
f"Expected a target_column, found {self.dataset_config.target_column}"
343+
)
344+
assert self.dataset_config.target_text_column is not None, (
345+
f"Expected a target_text_columns, found {self.dataset_config.target_text_column}"
346+
)
347347

348348
pipeline = pipeline.map(
349349
partial(
@@ -426,9 +426,9 @@ def config_post_init(self) -> None:
426426
)
427427

428428
if self.loading_config.even_sharding:
429-
assert (
430-
self.loading_config.seed is not None
431-
), "`even_sharding` sharding requires to seed to be set"
429+
assert self.loading_config.seed is not None, (
430+
"`even_sharding` sharding requires to seed to be set"
431+
)
432432

433433
if self.loading_config.max_tokens == 0:
434434
self.loading_config.max_tokens = None
@@ -876,9 +876,9 @@ def add_min_max_sentence_len_in_doc_filter(
876876
self.loading_config.max_sentence_len_in_doc
877877
or self.loading_config.min_sentence_len_in_doc
878878
):
879-
assert (
880-
self.dataset_config.source_text_column is not None
881-
), f"Expexted a source_text_columns, found {self.dataset_config.source_text_column}"
879+
assert self.dataset_config.source_text_column is not None, (
880+
f"Expexted a source_text_columns, found {self.dataset_config.source_text_column}"
881+
)
882882

883883
pipeline = pipeline.map(
884884
partial(
@@ -962,9 +962,9 @@ def add_quality_score_filters(
962962
if source_quality_range is None:
963963
return pipeline
964964

965-
assert (
966-
self.dataset_config.source_quality_column is not None
967-
), f"Expected a source_quality_columns, found {self.dataset_config.source_quality_column}"
965+
assert self.dataset_config.source_quality_column is not None, (
966+
f"Expected a source_quality_columns, found {self.dataset_config.source_quality_column}"
967+
)
968968

969969
pipeline = pipeline.map(
970970
partial(

lcm/datasets/parquet_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -486,9 +486,9 @@ def build_batching_loop_over_one_table(
486486
num_parallel_calls: int = 1,
487487
) -> DataPipeline:
488488
if max_tokens is not None:
489-
assert (
490-
length_column is not None
491-
), "Need to provide a column to compute the number of tokens"
489+
assert length_column is not None, (
490+
"Need to provide a column to compute the number of tokens"
491+
)
492492

493493
random_state = np.random.RandomState(seed)
494494
if length_column is not None and len(length_column) > 0:
@@ -1109,9 +1109,9 @@ def get_row_group_level_metadata(
11091109
columns_to_exclude = set(["row_group_id", "num_rows", "total_byte_size"]) & set(
11101110
columns
11111111
)
1112-
assert (
1113-
len(columns_to_exclude) == 0
1114-
), f"names conflict, rename/remove : {columns_to_exclude}"
1112+
assert len(columns_to_exclude) == 0, (
1113+
f"names conflict, rename/remove : {columns_to_exclude}"
1114+
)
11151115

11161116
def get_one_row_group_stats(row_group):
11171117
metadata = row_group.metadata

lcm/evaluation/arun.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0):
114114
)
115115

116116
if iteration_value is not None:
117-
assert (
118-
isinstance(iteration_value, int) and self.config.nshards
119-
), f"Invalid shard value ({self.config.nshards}) or iteration value ({iteration_value})"
120-
assert (
121-
self.config.data_loading
122-
), f"Data loading is not specified: \n {self.config}"
117+
assert isinstance(iteration_value, int) and self.config.nshards, (
118+
f"Invalid shard value ({self.config.nshards}) or iteration value ({iteration_value})"
119+
)
120+
assert self.config.data_loading, (
121+
f"Data loading is not specified: \n {self.config}"
122+
)
123123
self.config.data_loading.rank = iteration_value
124124
self.config.data_loading.world_size = int(self.config.nshards)
125125

@@ -194,9 +194,9 @@ async def schedule_task(
194194
result = (metrics, result_file)
195195

196196
result_metrics, result_file = result
197-
assert isinstance(
198-
result_metrics, dict
199-
), f"Expected Tuple[Dict[str, AverageMetrics], str], get {type(result_metrics)}"
197+
assert isinstance(result_metrics, dict), (
198+
f"Expected Tuple[Dict[str, AverageMetrics], str], get {type(result_metrics)}"
199+
)
200200

201201
metrics = {}
202202
cf = getattr(module.config, "confidence_level", None)

lcm/evaluation/cli/configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ class CliConfig:
7878

7979
def __post_init__(self) -> None:
8080
self.metric_log_dir = self.metric_log_dir or self.dump_dir
81-
assert (
82-
self.temperature >= 0.0
83-
), f"Expect non-zero temperature, get {self.temperature}"
81+
assert self.temperature >= 0.0, (
82+
f"Expect non-zero temperature, get {self.temperature}"
83+
)
8484
if self.temperature == 0:
8585
self.top_p = 0
8686
self.top_k = 0

lcm/evaluation/metrics/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def get_scorer(
5757
if "outputs" in defaults:
5858
output_columns = defaults["outputs"].default
5959
else:
60-
assert (
61-
config.model_name
62-
), f"Cannot resolve output name for the scorer type {scorer_cls}"
60+
assert config.model_name, (
61+
f"Cannot resolve output name for the scorer type {scorer_cls}"
62+
)
6363
output_columns = scorer_cls.default_outputs(config.model_name)
6464

6565
if isinstance(output_columns, str):

lcm/evaluation/metrics/multilingual_similarity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def translate(
6868
) -> List[str]:
6969
src_lang, tgt_lang = self.src_lang, self.tgt_lang
7070
sent_translations = []
71-
assert isinstance(
72-
self.model, EncoderDecoderModel
73-
), f"Unsupported type: {type(self.model)}"
71+
assert isinstance(self.model, EncoderDecoderModel), (
72+
f"Unsupported type: {type(self.model)}"
73+
)
7474
generator = BeamSearchSeq2SeqGenerator(
7575
self.model, echo_prompt=True, max_seq_len=self.max_seq_len
7676
)

lcm/evaluation/metrics/sentence_fluency.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ def score_texts(
200200
bos_token = bos_token or getattr(self.tokenizer, "bos_token", "\n")
201201
if eos_token != "":
202202
eos_token = eos_token or getattr(self.tokenizer, "eos_token", "\n")
203-
assert (
204-
eos_token is not None and bos_token is not None
205-
), "Expecting eos and bos tokens, for perplexity without any surrounding tokens, use eos_token='' and bos_token=''"
203+
assert eos_token is not None and bos_token is not None, (
204+
"Expecting eos and bos tokens, for perplexity without any surrounding tokens, use eos_token='' and bos_token=''"
205+
)
206206
logger.info(
207207
f"Computing perplexity with bos_token={repr(bos_token)} and eos_token={repr(eos_token)}"
208208
)
@@ -340,9 +340,9 @@ def backtranslate(
340340
translations = []
341341
back_translations = []
342342
losses = []
343-
assert isinstance(
344-
self.model, EncoderDecoderModel
345-
), f"Unsupported type: {type(self.model)}"
343+
assert isinstance(self.model, EncoderDecoderModel), (
344+
f"Unsupported type: {type(self.model)}"
345+
)
346346
generator = BeamSearchSeq2SeqGenerator(
347347
self.model, echo_prompt=True, max_seq_len=self.max_seq_len
348348
)

0 commit comments

Comments (0)