[WIP] Add Ray-KFT example for Ray based data processing and Kubeflow Traini… #458
Conversation
[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing /approve in a comment.
Walkthrough

Adds an end-to-end example pipeline under examples/ray-kft-v1: four notebooks (Ray SDG, Kubeflow training, TensorBoard, inference), a README, a synthetic dataset, and two new scripts for Ray Data synthetic data generation and the Granite fine-tuning training entrypoint.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant CodeFlare as CodeFlare SDK
    participant Ray as Ray Cluster
    participant SDG as SDG Job (ray_sdg_job.py)
    participant PVC as Shared Storage
    User->>CodeFlare: Authenticate + apply Ray cluster
    CodeFlare-->>User: Cluster ready (dashboard URI)
    User->>Ray: Submit Ray Data SDG job
    Ray->>SDG: Launch driver
    SDG->>SDG: Initialize ModelInferenceCallable
    loop per batch
        SDG->>Ray: map_batches(ModelInferenceCallable)
        SDG->>SDG: Parse responses + quality assessment
        SDG->>PVC: Save batch shard + checkpoint
    end
    SDG->>PVC: Write final dataset + metadata
    SDG-->>User: Completion status + artifact paths
```

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant KFT as Kubeflow Training SDK
    participant K8s as Kubernetes
    participant Workers as PyTorchJob Pods
    participant PVC as Shared Storage
    User->>KFT: Submit PyTorchJob (kft_granite_training.training_func)
    KFT->>K8s: Create Job (master + workers)
    K8s-->>KFT: Pods running
    Workers->>PVC: Load dataset/models/cache
    Workers->>Workers: Run TRL SFTTrainer + PEFT/LoRA training
    Workers->>PVC: Save checkpoints/artifacts
    User->>K8s: Monitor logs / inspect status
    User->>KFT: Delete Job (optional)
```

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant Notebook as Inference Notebook
    participant PVC as Shared Storage
    participant HF as HuggingFace Hub
    Notebook->>PVC: Attempt load tokenizer/base/fine-tuned
    alt Cache miss
        Notebook->>HF: Fetch models/tokenizer
    end
    Notebook->>Notebook: Merge LoRA into model for inference
    loop for each test sample
        Notebook->>Notebook: Generate base vs tuned outputs
        Notebook->>Notebook: Extract numeric answers + compare
    end
    Notebook-->>User: Accuracy summary + detailed analysis
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Force-pushed from 0ac5486 to 9e77072 (Compare)
Add Ray-KFT example for Ray based data processing and Kubeflow Training-Operator V1 based finetuning capabilities
Force-pushed from 9e77072 to 6f9e15f (Compare)
Actionable comments posted: 6
🧹 Nitpick comments (2)
examples/ray-kft-v1/4_test_inference.ipynb (2)
116-123: Consider more specific exception handling.

While broad Exception catches are flagged by static analysis, they may be acceptable in this notebook context for graceful degradation. However, catching more specific exceptions would improve debugging and make error handling more intentional. Consider catching specific exceptions where possible:

```python
# For file/model loading operations
except (OSError, ValueError) as e:
    print(f"Model loading error: {e}")
    # fallback logic

# For tokenizer/inference operations
except (RuntimeError, torch.cuda.CudaError) as e:
    print(f"Inference error: {e}")
```

This makes it clear which errors are expected and helps distinguish between recoverable errors and unexpected failures.
Also applies to: 149-162, 244-246, 348-349, 466-467, 486-487
351-363: Consider more robust answer extraction.

The current evaluation extracts only the final number from text (line 353: numbers[-1]), which may be too simplistic for math word problems. This approach could miss:
- Answers with multiple numeric components
- Non-numeric answers or units
- Intermediate calculations that happen to be the last number
For a production system, consider:
- More sophisticated parsing (e.g., looking for explicit answer markers like "The answer is")
- Semantic similarity between generated and expected answers
- Multiple evaluation metrics (exact match, numeric similarity, step coverage)
However, for a demonstration notebook, the current simple approach may be sufficient.
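For illustration, a marker-aware extractor might look like the following — a hypothetical helper sketched for this review, not code from the PR:

```python
import re
from typing import Optional

def extract_answer(text: str) -> Optional[str]:
    """Prefer an explicit answer marker; fall back to the last number."""
    # Look for an explicit marker such as "The answer is 42" first.
    marker = re.search(r"(?:the answer is|answer:)\s*\$?(-?\d+(?:\.\d+)?)",
                       text, re.IGNORECASE)
    if marker:
        return marker.group(1)
    # Fall back to the last number in the text, as the notebook does today.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
    return numbers[-1] if numbers else None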
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (8)

- examples/ray-kft-v1/docs/raycluster_dashboard.png is excluded by !**/*.png
- examples/ray-kft-v1/docs/rayjob_gpu_util.png is excluded by !**/*.png
- examples/ray-kft-v1/docs/rayjob_running_1.png is excluded by !**/*.png
- examples/ray-kft-v1/docs/rayjob_running_2.png is excluded by !**/*.png
- examples/ray-kft-v1/docs/rayjob_succeeded_1.png is excluded by !**/*.png
- examples/ray-kft-v1/docs/rayjob_succeeded_2.png is excluded by !**/*.png
- examples/ray-kft-v1/docs/tensorboard_1.png is excluded by !**/*.png
- examples/ray-kft-v1/docs/tensorboard_2.png is excluded by !**/*.png
📒 Files selected for processing (8)

- examples/ray-kft-v1/1_ray_sdg.ipynb (1 hunks)
- examples/ray-kft-v1/2_kft_training.ipynb (1 hunks)
- examples/ray-kft-v1/3_tensorboard_monitoring.ipynb (1 hunks)
- examples/ray-kft-v1/4_test_inference.ipynb (1 hunks)
- examples/ray-kft-v1/README.md (1 hunks)
- examples/ray-kft-v1/dataset/sample_synthetic_dataset.json (1 hunks)
- examples/ray-kft-v1/scripts/kft_granite_training.py (1 hunks)
- examples/ray-kft-v1/scripts/ray_sdg_job.py (1 hunks)
🧰 Additional context used
🪛 markdownlint-cli2 (0.18.1)
examples/ray-kft-v1/README.md
7-7: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
9-9: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
11-11: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
13-13: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
15-15: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
20-20: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
88-88: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
89-89: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
90-90: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
96-96: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
109-109: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
113-113: Emphasis used instead of a heading
(MD036, no-emphasis-as-heading)
🪛 Ruff (0.13.3)
examples/ray-kft-v1/scripts/ray_sdg_job.py
1-1: Shebang is present but file is not executable
(EXE001)
82-82: Consider moving this statement to an else block
(TRY300)
83-83: Do not catch blind exception: Exception
(BLE001)
109-109: Do not catch blind exception: Exception
(BLE001)
125-125: Do not catch blind exception: Exception
(BLE001)
158-158: Probable insecure usage of temporary file or directory: "/tmp/synthetic_data"
(S108)
158-158: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
202-202: Do not catch blind exception: Exception
(BLE001)
264-264: Consider moving this statement to an else block
(TRY300)
266-266: Do not catch blind exception: Exception
(BLE001)
276-276: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
276-276: Create your own exception
(TRY002)
276-276: Avoid specifying long messages outside the exception class
(TRY003)
304-304: Probable insecure usage of temporary file or directory: "/tmp/.cache"
(S108)
308-308: Probable insecure usage of temporary file or directory: "/tmp/.cache"
(S108)
348-348: Do not catch blind exception: Exception
(BLE001)
371-371: String contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF001)
375-375: String contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF001)
434-434: Do not use bare except
(E722)
434-435: try-except-pass detected, consider logging the exception
(S110)
439-439: Do not catch blind exception: Exception
(BLE001)
529-529: Unused method argument: question
(ARG002)
531-531: String contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF001)
548-548: String contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF001)
557-557: Do not use bare except
(E722)
557-558: try-except-continue detected, consider logging the exception
(S112)
683-683: Do not use bare except
(E722)
683-684: try-except-pass detected, consider logging the exception
(S110)
709-709: Do not catch blind exception: Exception
(BLE001)
805-805: Probable insecure usage of temporary file or directory: "/tmp/synthetic_data"
(S108)
877-877: f-string without any placeholders
Remove extraneous f prefix
(F541)
909-909: Unused function argument: frame
(ARG001)
980-980: f-string without any placeholders
Remove extraneous f prefix
(F541)
1007-1007: Local variable num_cpus is assigned to but never used
Remove assignment to unused variable num_cpus
(F841)
1013-1013: f-string without any placeholders
Remove extraneous f prefix
(F541)
1021-1021: f-string without any placeholders
Remove extraneous f prefix
(F541)
1048-1048: f-string without any placeholders
Remove extraneous f prefix
(F541)
1053-1053: f-string without any placeholders
Remove extraneous f prefix
(F541)
1118-1118: Do not catch blind exception: Exception
(BLE001)
1168-1168: f-string without any placeholders
Remove extraneous f prefix
(F541)
1173-1173: f-string without any placeholders
Remove extraneous f prefix
(F541)
examples/ray-kft-v1/4_test_inference.ipynb
49-49: f-string without any placeholders
Remove extraneous f prefix
(F541)
61-61: Do not catch blind exception: Exception
(BLE001)
94-94: Do not catch blind exception: Exception
(BLE001)
105-105: Do not catch blind exception: Exception
(BLE001)
118-118: Do not catch blind exception: Exception
(BLE001)
163-163: Do not catch blind exception: Exception
(BLE001)
174-174: f-string without any placeholders
Remove extraneous f prefix
(F541)
230-230: Consider moving this statement to an else block
(TRY300)
231-231: Do not catch blind exception: Exception
(BLE001)
310-310: Do not catch blind exception: Exception
(BLE001)
330-330: Do not catch blind exception: Exception
(BLE001)
337-337: f-string without any placeholders
Remove extraneous f prefix
(F541)
343-343: f-string without any placeholders
Remove extraneous f prefix
(F541)
351-351: f-string without any placeholders
Remove extraneous f prefix
(F541)
355-355: f-string without any placeholders
Remove extraneous f prefix
(F541)
367-367: f-string without any placeholders
Remove extraneous f prefix
(F541)
373-373: f-string without any placeholders
Remove extraneous f prefix
(F541)
389-389: f-string without any placeholders
Remove extraneous f prefix
(F541)
examples/ray-kft-v1/scripts/kft_granite_training.py
41-41: Local variable dataset_batch_size is assigned to but never used
Remove assignment to unused variable dataset_batch_size
(F841)
83-83: Local variable dataset_config is assigned to but never used
Remove assignment to unused variable dataset_config
(F841)
84-84: Local variable dataset_train_split is assigned to but never used
Remove assignment to unused variable dataset_train_split
(F841)
85-85: Local variable dataset_test_split is assigned to but never used
Remove assignment to unused variable dataset_test_split
(F841)
86-86: Local variable dataset_text_field is assigned to but never used
Remove assignment to unused variable dataset_text_field
(F841)
87-87: Local variable dataset_kwargs is assigned to but never used
Remove assignment to unused variable dataset_kwargs
(F841)
93-93: Local variable script_args is assigned to but never used
Remove assignment to unused variable script_args
(F841)
164-164: Avoid specifying long messages outside the exception class
(TRY003)
243-243: f-string without any placeholders
Remove extraneous f prefix
(F541)
245-245: f-string without any placeholders
Remove extraneous f prefix
(F541)
examples/ray-kft-v1/1_ray_sdg.ipynb
6-6: Possible hardcoded password assigned to: "token"
(S105)
examples/ray-kft-v1/2_kft_training.ipynb
7-7: Possible hardcoded password assigned to: "token"
(S105)
30-30: Undefined name training_parameters
(F821)
70-70: f-string without any placeholders
Remove extraneous f prefix
(F541)
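Most of these findings reduce to a few recurring patterns; a hedged sketch of the typical fixes (illustrative code, not taken from the PR):

```python
import json

raw = '{"question": "2 + 2?"}'  # stand-in input, for illustration only

# F541: drop the extraneous f prefix when a string has no placeholders.
print("Pipeline complete")  # rather than print(f"Pipeline complete")

try:
    payload = json.loads(raw)
except json.JSONDecodeError as err:  # E722/BLE001: catch something specific, not a bare except
    # B904: chain the original exception so tracebacks stay informative.
    raise ValueError("Malformed batch payload") from err
```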
🔇 Additional comments (3)
examples/ray-kft-v1/4_test_inference.ipynb (3)
32-34: Verify hard-coded paths match deployment environment.

The notebook uses hard-coded paths to shared storage (lines 32-34, 365) that differ from the sample dataset location in this PR (examples/ray-kft-v1/dataset/sample_synthetic_dataset.json). Ensure these paths match your actual deployment environment:

- Trained model: /opt/app-root/src/shared/models/granite-3.1-2b-instruct-synthetic2
- Test data: /opt/app-root/src/shared/synthetic_data_v2/synthetic_dataset.json
- Base model cache: /opt/app-root/src/shared/huggingface_cache/...

Consider parameterizing these paths or documenting the expected directory structure in the README to improve portability.
Also applies to: 365-365
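For example, the paths could fall back to the current values while allowing overrides — a hypothetical sketch (the environment variable names are assumptions, not part of the PR):

```python
import os

TRAINED_MODEL_PATH = os.environ.get(
    "TRAINED_MODEL_PATH",
    "/opt/app-root/src/shared/models/granite-3.1-2b-instruct-synthetic2",
)
TEST_DATA_PATH = os.environ.get(
    "TEST_DATA_PATH",
    "/opt/app-root/src/shared/synthetic_data_v2/synthetic_dataset.json",
)
```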
548-564: LGTM! Proper resource cleanup.

The cleanup logic appropriately handles model deletion and GPU memory management. Using locals() to check for variable existence before deletion is a safe approach for notebook environments where cells may be run out of order.
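As a sketch of that pattern (the variable names are assumed, not copied from the notebook):

```python
import gc
import torch

# Guard deletions so the cell is safe to re-run or run out of order.
if "trained_model" in locals():
    del trained_model
if "base_model" in locals():
    del base_model

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```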
304-330: LGTM! Inference logic is well-structured.

The generate_response function correctly:

- Applies chat template for proper formatting
- Handles tokenization with truncation
- Uses appropriate generation parameters (temperature=0.7, top_p=0.9) for math problem solving
- Properly removes the input prompt from the output

The sampling parameters balance creativity with consistency, which is suitable for generating mathematical explanations.
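A minimal sketch of that flow under the parameters noted above — the signature and variable names here are assumed, not copied from the notebook:

```python
def generate_response(model, tokenizer, question: str, max_new_tokens: int = 256) -> str:
    # Format the question with the model's chat template.
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,  # sampling parameters noted in the review
        top_p=0.9,
    )
    # Return only the newly generated tokens, dropping the prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```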
| " if isinstance(data, list):\n", | ||
| " total_samples = len(data)\n", | ||
| " avg_quality = sum(item.get('overall_quality', 0) for item in data) / total_samples if total_samples > 0 else 0\n", | ||
| " sample = data[0] if data else None\n", | ||
| " \n", | ||
| " print(f\" Dataset found: {total_samples} samples\")\n", | ||
| " print(f\" Avg quality: {avg_quality:.2f} \\n Source: {sample.get('source', 'N/A') if sample else 'N/A'}\")\n", | ||
| " \n", | ||
| " # Show sample\n", | ||
| " if sample:\n", | ||
| " print(f\" Sample Question -> {sample['question']}\")\n", | ||
| " print(f\" Sample Answer -> {sample['answer']}\")\n", | ||
| " \n", | ||
| " print(\"\\n Ready for training!\") \n", |
sample referenced before assignment for dict datasets

When the synthetic data is stored as a dict with train/test splits (the default output of save_final_dataset), the else branch never defines sample, but you still access if sample: a few lines later. That yields an UnboundLocalError, so the readiness check cell fails even though the dataset is valid. Initialise sample = None before the if and set it inside both branches, e.g.:

```diff
-if isinstance(data, list):
+sample = None
+if isinstance(data, list):
     ...
     sample = data[0] if data else None
 else:
     train = data.get("train", [])
     test = data.get("test", [])
     total_samples = len(train) + len(test)
     avg_quality = ...
+    sample = (train or test)[0] if (train or test) else None
```

That keeps the inspection working across both list and split formats.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/ray-kft-v1/1_ray_sdg.ipynb around lines 422 to 435, the variable
`sample` is only set inside the `if isinstance(data, list):` branch so when
`data` is a dict (train/test splits) `sample` is undefined and later `if
sample:` raises UnboundLocalError; initialize `sample = None` before the branch
and ensure you set `sample` in both branches (e.g., pick a sample from
`data['train']` or `data.get('train', [])` when `data` is a dict) so the
subsequent sample prints and readiness message work for both list and split
formats.
| "def load_test_data(file_path: str, num_samples: int = 5) -> List[Dict[str, Any]]:\n", | ||
| " \"\"\"Load test samples from synthetic dataset\"\"\"\n", | ||
| " try:\n", | ||
| " with open(file_path, 'r') as f:\n", | ||
| " data = json.load(f)\n", | ||
| " \n", | ||
| " # Take a subset for testing\n", | ||
| " if len(data) > num_samples:\n", | ||
| " # Take samples from different parts of the dataset\n", | ||
| " step = len(data) // num_samples\n", | ||
| " test_samples = [data[i] for i in range(0, len(data), step)][:num_samples]\n", | ||
| " else:\n", | ||
| " test_samples = data[:num_samples]\n", | ||
| " \n", | ||
| " return test_samples\n", | ||
| " except Exception as e:\n", | ||
| " print(f\"Error loading test data: {e}\")\n", | ||
| " return []\n", |
Data format mismatch will cause runtime failure.
The load_test_data function assumes the JSON file contains a flat list at the root level (line 336: data = json.load(f) followed by list operations like len(data) and data[i]). However, the sample dataset at examples/ray-kft-v1/dataset/sample_synthetic_dataset.json has a structured format with train, test, and metadata keys.
This mismatch will cause a runtime error when attempting to get the length or index into a dictionary.
Update the function to handle the structured format:
```diff
 def load_test_data(file_path: str, num_samples: int = 5) -> List[Dict[str, Any]]:
     """Load test samples from synthetic dataset"""
     try:
         with open(file_path, 'r') as f:
             data = json.load(f)
+        # Handle structured format with train/test partitions
+        if isinstance(data, dict) and 'test' in data:
+            data = data['test']
+        elif isinstance(data, dict) and 'train' in data:
+            data = data['train']
+
         # Take a subset for testing
         if len(data) > num_samples:
             # Take samples from different parts of the dataset
             step = len(data) // num_samples
             test_samples = [data[i] for i in range(0, len(data), step)][:num_samples]
         else:
             test_samples = data[:num_samples]
         return test_samples
     except Exception as e:
         print(f"Error loading test data: {e}")
         return []
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "def load_test_data(file_path: str, num_samples: int = 5) -> List[Dict[str, Any]]:\n", | |
| " \"\"\"Load test samples from synthetic dataset\"\"\"\n", | |
| " try:\n", | |
| " with open(file_path, 'r') as f:\n", | |
| " data = json.load(f)\n", | |
| " \n", | |
| " # Take a subset for testing\n", | |
| " if len(data) > num_samples:\n", | |
| " # Take samples from different parts of the dataset\n", | |
| " step = len(data) // num_samples\n", | |
| " test_samples = [data[i] for i in range(0, len(data), step)][:num_samples]\n", | |
| " else:\n", | |
| " test_samples = data[:num_samples]\n", | |
| " \n", | |
| " return test_samples\n", | |
| " except Exception as e:\n", | |
| " print(f\"Error loading test data: {e}\")\n", | |
| " return []\n", | |
| def load_test_data(file_path: str, num_samples: int = 5) -> List[Dict[str, Any]]: | |
| """Load test samples from synthetic dataset""" | |
| try: | |
| with open(file_path, 'r') as f: | |
| data = json.load(f) | |
| # Handle structured format with train/test partitions | |
| if isinstance(data, dict) and 'test' in data: | |
| data = data['test'] | |
| elif isinstance(data, dict) and 'train' in data: | |
| data = data['train'] | |
| # Take a subset for testing | |
| if len(data) > num_samples: | |
| # Take samples from different parts of the dataset | |
| step = len(data) // num_samples | |
| test_samples = [data[i] for i in range(0, len(data), step)][:num_samples] | |
| else: | |
| test_samples = data[:num_samples] | |
| return test_samples | |
| except Exception as e: | |
| print(f"Error loading test data: {e}") | |
| return [] |
🧰 Tools
🪛 Ruff (0.13.3)
337-337: f-string without any placeholders
Remove extraneous f prefix
(F541)
343-343: f-string without any placeholders
Remove extraneous f prefix
(F541)
| "source": "ray_sdg_qwen", | ||
| "seed_id": 0, | ||
| "variation_id": 0, | ||
| "difficulty": "Easy", |
Standardize difficulty field capitalization.
The difficulty field has inconsistent capitalization: "Easy" (line 10), "easy" (line 33), and "medium" (line 56). This inconsistency is reflected in the metadata's difficulty_distribution (lines 80-82) where "Easy" and "easy" are counted as separate categories.
This could cause issues in downstream processing such as filtering, grouping, or visualization.
Apply this standardization (choose lowercase for consistency):
- "difficulty": "Easy",
+ "difficulty": "easy",And update metadata:
"difficulty_distribution": {
- "Easy": 1,
- "easy": 1,
+ "easy": 2,
"medium": 1
},Also applies to: 33-33, 56-56, 80-82
🤖 Prompt for AI Agents
In examples/ray-kft-v1/dataset/sample_synthetic_dataset.json around lines 10,
33, 56 and 80-82, the difficulty values are inconsistently capitalized ("Easy",
"easy", "medium"); normalize all individual item "difficulty" fields to
lowercase (e.g., "easy", "medium") and update the
metadata.difficulty_distribution keys to the corresponding lowercase names,
merging counts so the distribution reflects aggregated lowercase categories.
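A one-off normalization pass along those lines might look like this — a hypothetical script that assumes the train/test/metadata layout described above:

```python
import json
from collections import Counter

path = "examples/ray-kft-v1/dataset/sample_synthetic_dataset.json"
with open(path) as f:
    data = json.load(f)

# Lowercase every item's difficulty across both splits.
items = data.get("train", []) + data.get("test", [])
for item in items:
    item["difficulty"] = item.get("difficulty", "").lower()

# Rebuild the distribution so merged categories are counted once.
data["metadata"]["difficulty_distribution"] = dict(
    Counter(item["difficulty"] for item in items)
)

with open(path, "w") as f:
    json.dump(data, f, indent=2)
```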
| print(f"HuggingFace cache directory: {os.environ['HF_HOME']}") | ||
| os.makedirs(os.environ['HF_HOME'], exist_ok=True) | ||
| os.makedirs(os.environ['HF_DATASETS_CACHE'], exist_ok=True) | ||
| os.makedirs('/shared/models', exist_ok=True) |
Guard HuggingFace cache env vars
Accessing os.environ['HF_HOME'] / ['HF_DATASETS_CACHE'] will raise KeyError whenever those variables aren’t pre-set (common for local testing and many Kubeflow images). That aborts the training before any logic runs. Please fall back to sensible defaults, e.g.:
- print(f"HuggingFace cache directory: {os.environ['HF_HOME']}")
- os.makedirs(os.environ['HF_HOME'], exist_ok=True)
- os.makedirs(os.environ['HF_DATASETS_CACHE'], exist_ok=True)
+ hf_home = os.environ.get("HF_HOME", "/shared/huggingface_cache")
+ hf_datasets_cache = os.environ.get("HF_DATASETS_CACHE", os.path.join(hf_home, "datasets"))
+ print(f"HuggingFace cache directory: {hf_home}")
+ os.makedirs(hf_home, exist_ok=True)
+ os.makedirs(hf_datasets_cache, exist_ok=True)Same idea applies to any later code using those env vars.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
hf_home = os.environ.get("HF_HOME", "/shared/huggingface_cache")
hf_datasets_cache = os.environ.get(
    "HF_DATASETS_CACHE",
    os.path.join(hf_home, "datasets")
)
print(f"HuggingFace cache directory: {hf_home}")
os.makedirs(hf_home, exist_ok=True)
os.makedirs(hf_datasets_cache, exist_ok=True)
os.makedirs('/shared/models', exist_ok=True)
```
🤖 Prompt for AI Agents
In examples/ray-kft-v1/scripts/kft_granite_training.py around lines 13 to 16,
the code directly indexes os.environ['HF_HOME'] and
os.environ['HF_DATASETS_CACHE'] which will raise KeyError if those env vars
aren’t set; change to use os.environ.get with sensible defaults (e.g.
HF_HOME=os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface')),
HF_DATASETS_CACHE=os.environ.get('HF_DATASETS_CACHE',
os.path.expanduser('~/.cache/huggingface/datasets'))), print those resolved
values, and call os.makedirs on the resolved variables (and '/shared/models')
with exist_ok=True; also update any later code to reference these resolved
variables rather than direct os.environ indexing.
| print("Loading synthetic dataset...") | ||
| dataset_paths = [ | ||
| "/shared/synthetic_data/synthetic_dataset.json", | ||
| "/shared/synthetic_data/final_synthetic_dataset.json" | ||
| ] | ||
|
|
||
| synthetic_data = None | ||
| for path in dataset_paths: | ||
| try: | ||
| with open(path, "r") as f: | ||
| synthetic_data = json.load(f) | ||
| print(f"Loaded synthetic dataset from: {path}") | ||
| break | ||
| except FileNotFoundError: | ||
| continue | ||
|
|
||
| if synthetic_data is None: | ||
| print("Synthetic dataset not found in any expected location:") | ||
| for path in dataset_paths: | ||
| print(f" - {path}") | ||
| print("Please run Ray preprocessing first.") | ||
| raise FileNotFoundError("No synthetic dataset found") |
Training job can’t locate the generated dataset
dataset_paths is limited to /shared/synthetic_data/..., but the Ray SDG script in this PR defaults to /tmp/synthetic_data, and your notebook example actually writes to /shared/synthetic_data_v2. Following either path causes this loader to raise FileNotFoundError, even though the data exists. Please either honour a configurable path (e.g. parameters['dataset_path']) or extend the search to match the Ray job outputs so the out-of-the-box flow works:
```diff
-    dataset_paths = [
-        "/shared/synthetic_data/synthetic_dataset.json",
-        "/shared/synthetic_data/final_synthetic_dataset.json"
-    ]
+    dataset_root = parameters.get("dataset_path") if parameters else None
+    candidate_roots = [
+        dataset_root,
+        "/shared/synthetic_data_v2",
+        "/shared/synthetic_data",
+        "/tmp/synthetic_data",
+    ]
+    dataset_paths = [
+        os.path.join(root, "synthetic_dataset.json")
+        for root in candidate_roots
+        if root
+    ] + [
+        os.path.join(root, "final_synthetic_dataset.json")
+        for root in candidate_roots
+        if root
+    ]
```

That keeps the training notebook and CLI instructions aligned with the generation step.
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.13.3)
164-164: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In examples/ray-kft-v1/scripts/kft_granite_training.py around lines 143 to 164,
the loader only checks /shared/synthetic_data/* which doesn't match other
locations used by the Ray SDG or notebook; update the code to first check for a
configurable dataset path (e.g. parameters.get('dataset_path') or an environment
variable) and fall back to an expanded list of candidate paths such as
/tmp/synthetic_data, /shared/synthetic_data, and /shared/synthetic_data_v2;
iterate those paths, attempt to open and json.load each, log the exact path that
succeeds, and raise FileNotFoundError only after all candidates (including the
configurable path if provided) fail.
```python
items = batch["items"]
total_generated += len(items)
processed_batches += 1

batch_seed_ids = set()
for item in items:
    if item["overall_quality"] >= quality_threshold:
        high_quality_count += 1
        quality_scores.append(item["overall_quality"])
        batch_seed_ids.add(item["seed_id"])

if batch_seed_ids:
    checkpoint_manager.save_checkpoint(batch_seed_ids, total_expected_seeds)

print(f"Processed batch {processed_batches}: {len(items)} items, {len(batch_seed_ids)} seeds")

# Check for progress stall
if total_generated == last_total_generated:
    batches_without_progress += 1
    if batches_without_progress >= max_batches_without_progress:
        print(f"No progress for {max_batches_without_progress} batches, stopping pipeline...")
        break
else:
    batches_without_progress = 0
    last_total_generated = total_generated

except KeyboardInterrupt:
    print("Pipeline interrupted by user")
except Exception as e:
    print(f"Pipeline error: {e}")
    print("Saving progress and continuing...")

end_time = time.time()
processing_time = end_time - start_time

# Create metadata
metadata = {
    "total_generated": total_generated,
    "high_quality_count": high_quality_count,
    "quality_pass_rate": (high_quality_count / total_generated * 100) if total_generated > 0 else 0,
    "quality_threshold": quality_threshold,
    "avg_quality_score": sum(quality_scores) / len(quality_scores) if quality_scores else 0,
    "processing_time_seconds": processing_time,
    "model_used": MODEL_NAME,
    "generation_method": "ray_data_distributed",
    "ray_data_features": [
        "map_batches_inference",
        "automatic_scaling",
        "fault_tolerance",
        "streaming_processing",
        "quality_filtering"
    ]
}

print("\n" + "="*60)
print("RAY DATA SDG PIPELINE SUMMARY")
print("="*60)
print(f"Total problems generated: {total_generated}")
print(f"High quality problems: {high_quality_count}")
print(f"Quality pass rate: {metadata['quality_pass_rate']:.1f}%")
print(f"Average quality score: {metadata['avg_quality_score']:.3f}")
print(f"Processing time: {processing_time:.1f} seconds")
print(f"Throughput: {total_generated/processing_time:.2f} problems/second")

if cluster_info['total_gpus'] > 0:
    gpu_efficiency = total_generated / (cluster_info['total_gpus'] * processing_time)
    print(f"GPU efficiency: {gpu_efficiency:.2f} problems/GPU/second")
    print(f"GPU utilization: {resource_config['total_workers']}/{cluster_info['total_gpus']} GPUs used")

print("="*60)

# Final checkpoint save
checkpoint_manager.save_checkpoint(checkpoint_manager.processed_seeds, total_expected_seeds, force=True)

# Save final results
if checkpoint_manager.current_data:
    save_final_dataset(checkpoint_manager.current_data, args.output_path, metadata)
```
Final dataset export never triggers
checkpoint_manager.current_data is never populated during a fresh run, because we never call CheckpointManager.save_batch_data. Consequently the final export (Lines 1164-1167) is skipped and the pipeline finishes without producing the final_synthetic_dataset.json, even though batches were generated and written out by the workers. Persist the consolidated batch items through the checkpoint manager when iterating the pipeline so current_data stays in sync before the final save. One simple way to do that:
```diff
 for item in items:
     if item["overall_quality"] >= quality_threshold:
         high_quality_count += 1
         quality_scores.append(item["overall_quality"])
         batch_seed_ids.add(item["seed_id"])
 if batch_seed_ids:
     checkpoint_manager.save_checkpoint(batch_seed_ids, total_expected_seeds)
+if items:
+    checkpoint_manager.save_batch_data(items)
```

You can then consider dropping the ad-hoc _save_batch_data inside ModelInferenceCallable to avoid double writes.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
if batch_seed_ids:
    checkpoint_manager.save_checkpoint(batch_seed_ids, total_expected_seeds)
if items:
    checkpoint_manager.save_batch_data(items)
print(f"Processed batch {processed_batches}: {len(items)} items, {len(batch_seed_ids)} seeds")
```
🧰 Tools
🪛 Ruff (0.13.3)
1118-1118: Do not catch blind exception: Exception
(BLE001)
Hi @abhijeet-dhumal, this looks great! It could be a few days before I have time to fully review it. Can I ask about the context or the ask that drove this work? Is there an issue/Jira for tracking?
Description
RHOAIENG-34071
How Has This Been Tested?
Merge criteria:
Summary by CodeRabbit