|
7 | 7 | from urllib.parse import urlparse |
8 | 8 |
|
9 | 9 |
|
10 | | -def write_dataset_to_s3( |
11 | | - df: pd.DataFrame, bucket: str, key_prefix: str, format: str |
| 10 | +def write_dataset( |
| 11 | + df: pd.DataFrame, data_dir: str, file_prefix: str, format: str |
| 12 | +) -> str: |
| 13 | + if is_s3_path(data_dir): |
| 14 | + bucket, prefix = parse_s3_path(data_dir) |
| 15 | + return write_dataset_s3(df, bucket, prefix, file_prefix, format) |
| 16 | + else: |
| 17 | + return write_dataset_local(df, data_dir, file_prefix, format) |
| 18 | + |
| 19 | +def write_dataset_s3( |
| 20 | + df: pd.DataFrame, bucket: str, prefix: str, file_prefix: str, format: str |
12 | 21 | ) -> str: |
13 | 22 | with tempfile.TemporaryDirectory() as temp_dir: |
14 | | - temp_file = os.path.join(temp_dir, "temp.jsonl") |
15 | | - df.to_json(temp_file, orient="records", lines=bool(format == "jsonl")) |
| 23 | + temp_file = os.path.join(temp_dir, "temp.csv") |
| 24 | + df.to_csv(temp_file, index=False) |
16 | 25 | s3_client = boto3.client("s3") |
17 | | - key = add_timestamp_to_file_prefix(key_prefix, format) |
18 | | - print(f"Writing dataset to bucket {bucket} and key {key}.") |
| 26 | + key = os.path.join(prefix, |
| 27 | + add_timestamp_to_file_prefix(file_prefix, format) |
| 28 | + ) |
| 29 | + print(f"Writing dataset to s3://{bucket}/{key}") |
19 | 30 | s3_client.upload_file(temp_file, bucket, key) |
20 | 31 | return f"s3://{bucket}/{key}" |
21 | 32 |
|
22 | | - |
23 | | -def write_dataset_local(df: pd.DataFrame, data_dir: str, file_prefix: str) -> str: |
| 33 | +def write_dataset_local( |
| 34 | + df: pd.DataFrame, data_dir: str, file_prefix: str, format: str |
| 35 | + ) -> str: |
24 | 36 | # Expand home directory and create if needed |
25 | 37 | data_dir = os.path.expanduser(data_dir) |
26 | 38 | os.makedirs(data_dir, exist_ok=True) |
27 | 39 |
|
28 | 40 | output_path = os.path.join( |
29 | | - data_dir, add_timestamp_to_file_prefix(file_prefix, "csv") |
| 41 | + data_dir, |
| 42 | + add_timestamp_to_file_prefix(file_prefix, format) |
30 | 43 | ) |
31 | 44 | df.to_csv(output_path, index=False) |
32 | 45 | print(f"Saved to {output_path}") |
|
0 commit comments