-
Notifications
You must be signed in to change notification settings - Fork 70
test: initial implementation of SDK e2e #488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| bin/* | ||
| .vscode/* |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| /* | ||
| Copyright 2025. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| */ | ||
|
|
||
| package trainer | ||
|
|
||
| import ( | ||
| "testing" | ||
|
|
||
| . "github.com/opendatahub-io/distributed-workloads/tests/common" | ||
| sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests" | ||
| ) | ||
|
|
||
| func TestKubeflowSDK_Sanity(t *testing.T) { | ||
| Tags(t, Sanity) | ||
| sdktests.RunFashionMnistCpuDistributedTraining(t) | ||
| // ADD MORE SANITY TESTS HERE | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,275 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 1, | ||
| "metadata": { | ||
| "execution": { | ||
| "iopub.execute_input": "2025-09-03T13:19:46.917723Z", | ||
| "iopub.status.busy": "2025-09-03T13:19:46.917308Z", | ||
| "iopub.status.idle": "2025-09-03T13:19:46.935181Z", | ||
| "shell.execute_reply": "2025-09-03T13:19:46.934697Z", | ||
| "shell.execute_reply.started": "2025-09-03T13:19:46.917698Z" | ||
| } | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "def train_fashion_mnist():\n", | ||
| " import os\n", | ||
| "\n", | ||
| " import torch\n", | ||
| " import torch.distributed as dist\n", | ||
| " import torch.nn.functional as F\n", | ||
| " from torch import nn\n", | ||
| " from torch.utils.data import DataLoader, DistributedSampler\n", | ||
| " from torchvision import datasets, transforms\n", | ||
| "\n", | ||
| " # Define the PyTorch CNN model to be trained\n", | ||
| " class Net(nn.Module):\n", | ||
| " def __init__(self):\n", | ||
| " super(Net, self).__init__()\n", | ||
| " self.conv1 = nn.Conv2d(1, 20, 5, 1)\n", | ||
| " self.conv2 = nn.Conv2d(20, 50, 5, 1)\n", | ||
| " self.fc1 = nn.Linear(4 * 4 * 50, 500)\n", | ||
| " self.fc2 = nn.Linear(500, 10)\n", | ||
| "\n", | ||
| " def forward(self, x):\n", | ||
| " x = F.relu(self.conv1(x))\n", | ||
| " x = F.max_pool2d(x, 2, 2)\n", | ||
| " x = F.relu(self.conv2(x))\n", | ||
| " x = F.max_pool2d(x, 2, 2)\n", | ||
| " x = x.view(-1, 4 * 4 * 50)\n", | ||
| " x = F.relu(self.fc1(x))\n", | ||
| " x = self.fc2(x)\n", | ||
| " return F.log_softmax(x, dim=1)\n", | ||
| "\n", | ||
| " # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n", | ||
| " device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n", | ||
| " print(f\"Using Device: {device}, Backend: {backend}\")\n", | ||
| "\n", | ||
| " # Setup PyTorch distributed.\n", | ||
| " local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n", | ||
| " dist.init_process_group(backend=backend)\n", | ||
| " print(\n", | ||
| " \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n", | ||
| " dist.get_world_size(),\n", | ||
| " dist.get_rank(),\n", | ||
| " local_rank,\n", | ||
| " )\n", | ||
| " )\n", | ||
| "\n", | ||
| " # Create the model and load it into the device.\n", | ||
| " device = torch.device(f\"{device}:{local_rank}\")\n", | ||
| " model = nn.parallel.DistributedDataParallel(Net().to(device))\n", | ||
| " optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n", | ||
| "\n", | ||
| " \n", | ||
| " # Download FashionMNIST dataset only on local_rank=0 process.\n", | ||
| " if local_rank == 0:\n", | ||
| " dataset = datasets.FashionMNIST(\n", | ||
| " \"./data\",\n", | ||
| " train=True,\n", | ||
| " download=True,\n", | ||
| " transform=transforms.Compose([transforms.ToTensor()]),\n", | ||
| " )\n", | ||
| " dist.barrier()\n", | ||
| " dataset = datasets.FashionMNIST(\n", | ||
| " \"./data\",\n", | ||
| " train=True,\n", | ||
| " download=False,\n", | ||
| " transform=transforms.Compose([transforms.ToTensor()]),\n", | ||
| " )\n", | ||
| "\n", | ||
| "\n", | ||
| " # Shard the dataset accross workers.\n", | ||
| " train_loader = DataLoader(\n", | ||
| " dataset,\n", | ||
| " batch_size=100,\n", | ||
| " sampler=DistributedSampler(dataset)\n", | ||
| " )\n", | ||
| "\n", | ||
| " # TODO(astefanutti): add parameters to the training function\n", | ||
| " dist.barrier()\n", | ||
| " for epoch in range(1, 3):\n", | ||
| " model.train()\n", | ||
| "\n", | ||
| " # Iterate over mini-batches from the training set\n", | ||
| " for batch_idx, (inputs, labels) in enumerate(train_loader):\n", | ||
| " # Copy the data to the GPU device if available\n", | ||
| " inputs, labels = inputs.to(device), labels.to(device)\n", | ||
| " # Forward pass\n", | ||
| " outputs = model(inputs)\n", | ||
| " loss = F.nll_loss(outputs, labels)\n", | ||
| " # Backward pass\n", | ||
| " optimizer.zero_grad()\n", | ||
| " loss.backward()\n", | ||
| " optimizer.step()\n", | ||
| "\n", | ||
| " if batch_idx % 10 == 0 and dist.get_rank() == 0:\n", | ||
| " print(\n", | ||
| " \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n", | ||
| " epoch,\n", | ||
| " batch_idx * len(inputs),\n", | ||
| " len(train_loader.dataset),\n", | ||
| " 100.0 * batch_idx / len(train_loader),\n", | ||
| " loss.item(),\n", | ||
| " )\n", | ||
| " )\n", | ||
| "\n", | ||
| " # Wait for the distributed training to complete\n", | ||
| " dist.barrier()\n", | ||
| " if dist.get_rank() == 0:\n", | ||
| " print(\"Training is finished\")\n", | ||
| "\n", | ||
| " # Finally clean up PyTorch distributed\n", | ||
| " dist.destroy_process_group()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "execution": { | ||
| "iopub.execute_input": "2025-09-03T13:19:49.832393Z", | ||
| "iopub.status.busy": "2025-09-03T13:19:49.832117Z", | ||
| "iopub.status.idle": "2025-09-03T13:19:51.924613Z", | ||
| "shell.execute_reply": "2025-09-03T13:19:51.924264Z", | ||
| "shell.execute_reply.started": "2025-09-03T13:19:49.832371Z" | ||
| }, | ||
| "pycharm": { | ||
| "name": "#%%\n" | ||
| } | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from kubeflow.trainer import CustomTrainer, TrainerClient\n", | ||
| "\n", | ||
| "client = TrainerClient()\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "for runtime in client.list_runtimes():\n", | ||
| " print(runtime)\n", | ||
| " if runtime.name == \"universal\": # Update to actual universal image runtime once available\n", | ||
| " torch_runtime = runtime" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this approach instead of getting universal runtime by
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That notebook is basically a copy from usptream test that SDK upstream relies on. I wanted to keep it as close as possible to original but I guess there's no harm in moving to get_runtime. Thanks Karel! |
||
| ] | ||
|
Comment on lines
+156
to
+160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainHardcoded runtime name needs to be resolved. The comment on line 158 indicates this uses a placeholder runtime name "universal" that needs updating. This conflicts with the hardcoded runtime checks in The PR description mentions needing a custom clusterTrainingRuntime named "universal", but the cluster prep utility checks for different names. Ensure consistency between:
Consider parameterizing the runtime name or documenting the expected setup more clearly. Length of output: 217 🏁 Script executed: #!/bin/bash
# Examine utils_cluster_prep.go around lines 53-54
fd -e go -name 'utils_cluster_prep.go' --exec cat {} \; | head -70 | tail -25Length of output: 230 🏁 Script executed: #!/bin/bash
# Search for runtime-related configurations and definitions
rg -n "runtime" -i --type go -A2 -B2 | head -100Length of output: 7893 Length of output: 37 🏁 Script executed: #!/bin/bash
# Once found, extract notebook content and look for "universal" runtime reference
find . -name "mnist.ipynb" -path "*/tests/trainer/resources/*" -type f -exec python3 << 'EOF' {} \;
import json
import sys
try:
with open(sys.argv[1], 'r') as f:
notebook = json.load(f)
# Search through all cells for "universal" string
for cell_idx, cell in enumerate(notebook.get('cells', [])):
source = cell.get('source', [])
if isinstance(source, str):
source_lines = source.split('\n')
else:
source_lines = source
for line_idx, line in enumerate(source_lines):
if 'universal' in line.lower() or 'torch_runtime' in line.lower():
print(f"Cell {cell_idx}, Line {line_idx}: {line}")
except Exception as e:
print(f"Error processing {sys.argv[1]}: {e}", file=sys.stderr)
EOFLength of output: 218 🏁 Script executed: #!/bin/bash
# Also search codebase for references to "universal" runtime
rg -n "universal" --type goLength of output: 113 Length of output: 108 🏁 Script executed: #!/bin/bash
# Search for torch_runtime usage in notebook
rg -n "torch_runtime\|runtime" ./tests/trainer/resources/mnist.ipynb | head -30Length of output: 0 🏁 Script executed: #!/bin/bash
# Extract notebook with cat and parse JSON properly
cat ./tests/trainer/resources/mnist.ipynb | python3 << 'PYTHON_EOF'
import json
import sys
try:
nb = json.load(sys.stdin)
for idx, cell in enumerate(nb.get('cells', [])):
source = cell.get('source', [])
if isinstance(source, list):
content = ''.join(source)
else:
content = source
# Look for runtime-related content
if 'universal' in content.lower() or 'torch_runtime' in content.lower():
lines = content.split('\n')
for line_num, line in enumerate(lines):
if 'universal' in line.lower() or 'torch_runtime' in line.lower():
print(f"Cell {idx}: {line}")
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
PYTHON_EOFLength of output: 48 Update notebook runtime selection to match validated test runtimes. The notebook hardcodes a check for Align the notebook to use one of the currently available runtimes (e.g., "torch-cuda-241") or parameterize it with a fallback mechanism. 🤖 Prompt for AI Agents |
||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "execution": { | ||
| "iopub.execute_input": "2025-09-03T13:19:56.525591Z", | ||
| "iopub.status.busy": "2025-09-03T13:19:56.524936Z", | ||
| "iopub.status.idle": "2025-09-03T13:19:56.721404Z", | ||
| "shell.execute_reply": "2025-09-03T13:19:56.720565Z", | ||
| "shell.execute_reply.started": "2025-09-03T13:19:56.525536Z" | ||
| } | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "job_name = client.train(\n", | ||
| " trainer=CustomTrainer(\n", | ||
| " func=train_fashion_mnist,\n", | ||
| " num_nodes=2,\n", | ||
| " resources_per_node={\n", | ||
| " \"cpu\": 2,\n", | ||
| " \"memory\": \"8Gi\",\n", | ||
| " },\n", | ||
| " packages_to_install=[\"torchvision\"],\n", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be good to install specific version, to make sure that a future upgrade doesn't break test.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes I think it makes sense. Thanks. |
||
| " ),\n", | ||
| " runtime=torch_runtime,\n", | ||
| ")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "execution": { | ||
| "iopub.execute_input": "2025-09-03T13:20:01.378158Z", | ||
| "iopub.status.busy": "2025-09-03T13:20:01.377707Z", | ||
| "iopub.status.idle": "2025-09-03T13:20:12.713960Z", | ||
| "shell.execute_reply": "2025-09-03T13:20:12.713295Z", | ||
| "shell.execute_reply.started": "2025-09-03T13:20:01.378130Z" | ||
| } | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Wait for the running status.\n", | ||
| "client.wait_for_job_status(name=job_name, status={\"Running\"})" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "execution": { | ||
| "iopub.execute_input": "2025-09-03T13:20:24.045774Z", | ||
| "iopub.status.busy": "2025-09-03T13:20:24.045480Z", | ||
| "iopub.status.idle": "2025-09-03T13:20:24.772877Z", | ||
| "shell.execute_reply": "2025-09-03T13:20:24.772178Z", | ||
| "shell.execute_reply.started": "2025-09-03T13:20:24.045755Z" | ||
| } | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "for c in client.get_job(name=job_name).steps:\n", | ||
| " print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "execution": { | ||
| "iopub.execute_input": "2025-09-03T13:20:26.729486Z", | ||
| "iopub.status.busy": "2025-09-03T13:20:26.728951Z", | ||
| "iopub.status.idle": "2025-09-03T13:20:29.596510Z", | ||
| "shell.execute_reply": "2025-09-03T13:20:29.594741Z", | ||
| "shell.execute_reply.started": "2025-09-03T13:20:26.729446Z" | ||
| } | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "for logline in client.get_job_logs(job_name, follow=True):\n", | ||
| " print(logline)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "client.delete_job(job_name)" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "Python 3 (ipykernel)", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.11.13" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 4 | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| /* | ||
| Copyright 2025. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| */ | ||
|
|
||
| package sdk_tests | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "os" | ||
| "testing" | ||
|
|
||
| . "github.com/onsi/gomega" | ||
|
|
||
| corev1 "k8s.io/api/core/v1" | ||
|
|
||
| common "github.com/opendatahub-io/distributed-workloads/tests/common" | ||
| support "github.com/opendatahub-io/distributed-workloads/tests/common/support" | ||
| trainerutils "github.com/opendatahub-io/distributed-workloads/tests/trainer/utils" | ||
| ) | ||
|
|
||
| const ( | ||
| notebookName = "mnist.ipynb" | ||
| notebookPath = "resources/" + notebookName | ||
| ) | ||
|
|
||
| // CPU Only - Distributed Training | ||
| func RunFashionMnistCpuDistributedTraining(t *testing.T) { | ||
| test := support.With(t) | ||
|
|
||
| // Create a new test namespace | ||
| namespace := test.NewTestNamespace() | ||
|
|
||
| // Ensure pre-requisites to run the test are met | ||
| trainerutils.EnsureTrainerClusterReady(t, test) | ||
|
|
||
| // Ensure Notebook SA and RBACs are set for this namespace | ||
| trainerutils.EnsureNotebookRBAC(t, test, namespace.Name) | ||
|
|
||
| // RBACs setup | ||
| userName := common.GetNotebookUserName(test) | ||
| userToken := common.GetNotebookUserToken(test) | ||
| support.CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin") | ||
|
|
||
| // Read notebook from directory | ||
| localPath := notebookPath | ||
| nb, err := os.ReadFile(localPath) | ||
| test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read notebook: %s", localPath)) | ||
|
|
||
| // Create ConfigMap with notebook | ||
| cm := support.CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: nb}) | ||
|
|
||
| // Build command | ||
| marker := "/opt/app-root/src/notebook_completion_marker" | ||
| shellCmd := trainerutils.BuildPapermillShellCmd(notebookName, marker, nil) | ||
| command := []string{"/bin/sh", "-c", shellCmd} | ||
|
|
||
| // Create Notebook CR (with default 10Gi PVC) | ||
| pvc := support.CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", support.AccessModes(corev1.ReadWriteOnce)) | ||
| common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, 0, pvc, common.ContainerSizeSmall) | ||
|
|
||
| // Cleanup | ||
| defer func() { | ||
| common.DeleteNotebook(test, namespace) | ||
| test.Eventually(common.Notebooks(test, namespace), support.TestTimeoutLong).Should(HaveLen(0)) | ||
| }() | ||
|
|
||
| // Wait for the Notebook Pod and get pod/container names | ||
| podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it have sense to add assertion checking that TrainJob is created and successfully finished?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think this is very important :) |
||
| // Poll marker file to check if the notebook execution completed successfully | ||
| if err := trainerutils.PollNotebookCompletionMarker(test, namespace.Name, podName, containerName, marker, support.TestTimeoutDouble); err != nil { | ||
| test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE") | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will become an issue for running tests on disconnected clusters.
In Trainer v1 tests we uploaded dataset on AWS S3, it is downloaded from there if AWS env variables are declared - https://github.com/opendatahub-io/distributed-workloads/blob/main/tests/kfto/resources/kfto_sdk_mnist.py#L67