Skip to content

Commit 4a925b1

Browse files
authored
[Batched Storage] Overwrite file when batch_step=0 (#432)
1 parent e1800c7 commit 4a925b1

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

dataflow/utils/storage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1060,7 +1060,8 @@ def clean_surrogates(obj):
10601060
os.makedirs(os.path.dirname(file_path), exist_ok=True)
10611061
self.logger.success(f"Writing data to {file_path} with type {self.cache_type}")
10621062
if self.cache_type == "jsonl":
1063-
with open(file_path, 'a', encoding='utf-8') as f:
1063+
open_mode = 'w' if self.batch_step == 0 else 'a'
1064+
with open(file_path, open_mode, encoding='utf-8') as f:
10641065
dataframe.to_json(f, orient="records", lines=True, force_ascii=False)
10651066
elif self.cache_type == "csv":
10661067
if self.batch_step == 0:

test/test_batched_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,5 @@ def forward(self):
5656
if __name__ == "__main__":
5757
pipeline = AutoOPPipeline()
5858
pipeline.compile()
59-
pipeline.forward(batch_size=2, resume_from_last=True)
59+
pipeline.forward(batch_size=2, resume_from_last=True)
60+
pipeline.forward(batch_size=2, resume_from_last=False) # should overwrite

0 commit comments

Comments
 (0)