1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
bin/*
.vscode/*
30 changes: 30 additions & 0 deletions tests/trainer/kubeflow_sdk_test.go
@@ -0,0 +1,30 @@
/*
Copyright 2025.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package trainer

import (
"testing"

. "github.com/opendatahub-io/distributed-workloads/tests/common"
sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests"
)

func TestKubeflowSDK_Sanity(t *testing.T) {
Tags(t, Sanity)
sdktests.RunFashionMnistCpuDistributedTraining(t)
// ADD MORE SANITY TESTS HERE
}
275 changes: 275 additions & 0 deletions tests/trainer/resources/mnist.ipynb
@@ -0,0 +1,275 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-03T13:19:46.917723Z",
"iopub.status.busy": "2025-09-03T13:19:46.917308Z",
"iopub.status.idle": "2025-09-03T13:19:46.935181Z",
"shell.execute_reply": "2025-09-03T13:19:46.934697Z",
"shell.execute_reply.started": "2025-09-03T13:19:46.917698Z"
}
},
"outputs": [],
"source": [
"def train_fashion_mnist():\n",
" import os\n",
"\n",
" import torch\n",
" import torch.distributed as dist\n",
" import torch.nn.functional as F\n",
" from torch import nn\n",
" from torch.utils.data import DataLoader, DistributedSampler\n",
" from torchvision import datasets, transforms\n",
"\n",
" # Define the PyTorch CNN model to be trained\n",
" class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
" self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
" self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
" self.fc2 = nn.Linear(500, 10)\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.conv1(x))\n",
" x = F.max_pool2d(x, 2, 2)\n",
" x = F.relu(self.conv2(x))\n",
" x = F.max_pool2d(x, 2, 2)\n",
" x = x.view(-1, 4 * 4 * 50)\n",
" x = F.relu(self.fc1(x))\n",
" x = self.fc2(x)\n",
" return F.log_softmax(x, dim=1)\n",
"\n",
" # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n",
" device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n",
" print(f\"Using Device: {device}, Backend: {backend}\")\n",
"\n",
" # Setup PyTorch distributed.\n",
" local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n",
" dist.init_process_group(backend=backend)\n",
" print(\n",
" \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n",
" dist.get_world_size(),\n",
" dist.get_rank(),\n",
" local_rank,\n",
" )\n",
" )\n",
"\n",
" # Create the model and load it into the device.\n",
" device = torch.device(f\"{device}:{local_rank}\")\n",
" model = nn.parallel.DistributedDataParallel(Net().to(device))\n",
" optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n",
"\n",
" \n",
" # Download FashionMNIST dataset only on local_rank=0 process.\n",
" if local_rank == 0:\n",
" dataset = datasets.FashionMNIST(\n",
" \"./data\",\n",
" train=True,\n",
" download=True,\n",
Contributor

This will become an issue for running tests on disconnected clusters.
In the Trainer v1 tests we uploaded the dataset to AWS S3, and it is downloaded from there if the AWS env variables are declared: https://github.com/opendatahub-io/distributed-workloads/blob/main/tests/kfto/resources/kfto_sdk_mnist.py#L67
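
A minimal sketch of the S3 fallback pattern the reviewer points to, in the spirit of the linked v1 helper; the environment variable names, bucket layout, and use of boto3 are assumptions for illustration, not the v1 test's actual values:

```python
import gzip
import os
import shutil

from torchvision import datasets, transforms

RAW_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def fetch_fashion_mnist(data_dir="./data"):
    """Load FashionMNIST from S3 when AWS env variables are set, otherwise from the internet."""
    bucket = os.getenv("AWS_STORAGE_BUCKET")      # hypothetical variable name
    endpoint = os.getenv("AWS_DEFAULT_ENDPOINT")  # hypothetical variable name

    if bucket and endpoint:
        import boto3  # assumed to be available in the runtime image

        s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
        )
        raw_dir = os.path.join(data_dir, "FashionMNIST", "raw")
        os.makedirs(raw_dir, exist_ok=True)
        for name in RAW_FILES:
            gz_path = os.path.join(raw_dir, name)
            s3.download_file(bucket, f"FashionMNIST/raw/{name}", gz_path)  # hypothetical object key
            # torchvision expects the decompressed files when download=False
            with gzip.open(gz_path, "rb") as src, open(gz_path[:-3], "wb") as dst:
                shutil.copyfileobj(src, dst)
        download = False
    else:
        download = True

    return datasets.FashionMNIST(
        data_dir,
        train=True,
        download=download,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
```

Only local_rank 0 would need to run the download branch; the other workers can keep waiting on dist.barrier() as the notebook already does.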

" transform=transforms.Compose([transforms.ToTensor()]),\n",
" )\n",
" dist.barrier()\n",
" dataset = datasets.FashionMNIST(\n",
" \"./data\",\n",
" train=True,\n",
" download=False,\n",
" transform=transforms.Compose([transforms.ToTensor()]),\n",
" )\n",
"\n",
"\n",
" # Shard the dataset accross workers.\n",
" train_loader = DataLoader(\n",
" dataset,\n",
" batch_size=100,\n",
" sampler=DistributedSampler(dataset)\n",
" )\n",
"\n",
" # TODO(astefanutti): add parameters to the training function\n",
" dist.barrier()\n",
" for epoch in range(1, 3):\n",
" model.train()\n",
"\n",
" # Iterate over mini-batches from the training set\n",
" for batch_idx, (inputs, labels) in enumerate(train_loader):\n",
" # Copy the data to the GPU device if available\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
" # Forward pass\n",
" outputs = model(inputs)\n",
" loss = F.nll_loss(outputs, labels)\n",
" # Backward pass\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if batch_idx % 10 == 0 and dist.get_rank() == 0:\n",
" print(\n",
" \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n",
" epoch,\n",
" batch_idx * len(inputs),\n",
" len(train_loader.dataset),\n",
" 100.0 * batch_idx / len(train_loader),\n",
" loss.item(),\n",
" )\n",
" )\n",
"\n",
" # Wait for the distributed training to complete\n",
" dist.barrier()\n",
" if dist.get_rank() == 0:\n",
" print(\"Training is finished\")\n",
"\n",
" # Finally clean up PyTorch distributed\n",
" dist.destroy_process_group()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-03T13:19:49.832393Z",
"iopub.status.busy": "2025-09-03T13:19:49.832117Z",
"iopub.status.idle": "2025-09-03T13:19:51.924613Z",
"shell.execute_reply": "2025-09-03T13:19:51.924264Z",
"shell.execute_reply.started": "2025-09-03T13:19:49.832371Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from kubeflow.trainer import CustomTrainer, TrainerClient\n",
"\n",
"client = TrainerClient()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for runtime in client.list_runtimes():\n",
" print(runtime)\n",
" if runtime.name == \"universal\": # Update to actual universal image runtime once available\n",
" torch_runtime = runtime"
Contributor

Why this approach instead of getting the universal runtime via client.get_runtime("universal")?

Contributor Author

That notebook is basically a copy of the upstream test that the SDK relies on. I wanted to keep it as close as possible to the original, but I guess there's no harm in moving to get_runtime. Thanks Karel!
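
For reference, a sketch of the suggested direct lookup; "universal" is still the placeholder runtime name discussed elsewhere in this review:

```python
from kubeflow.trainer import TrainerClient

client = TrainerClient()

# Fetch the runtime by name instead of iterating over list_runtimes().
torch_runtime = client.get_runtime("universal")  # placeholder runtime name
print(torch_runtime)
```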

]
Comment on lines +156 to +160

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Hardcoded runtime name needs to be resolved.

The comment on line 158 indicates this uses a placeholder runtime name "universal" that needs updating. This conflicts with the hardcoded runtime checks in utils_cluster_prep.go (lines 53-54) which expect different runtime names.

The PR description mentions needing a custom clusterTrainingRuntime named "universal", but the cluster prep utility checks for different names. Ensure consistency between:

  1. The runtime name expected by this notebook (line 158)
  2. The runtimes validated in EnsureTrainerClusterReady (utils_cluster_prep.go:53)
  3. The actual runtime available in the test environment

Consider parameterizing the runtime name or documenting the expected setup more clearly.



---

🏁 Script executed:

```shell
#!/bin/bash
# Extract and examine the notebook file structure and content around lines 156-160
fd -e ipynb -path '*/tests/trainer/resources/mnist.ipynb' --exec cat {} \; | python3 << 'EOF'
import json
import sys

try:
    data = json.load(sys.stdin)
    cells = data.get('cells', [])
    
    # Flatten all cells and count lines
    line_count = 0
    for cell_idx, cell in enumerate(cells):
        source = cell.get('source', [])
        if isinstance(source, str):
            source = source.split('\n')
        
        for line_idx, line in enumerate(source):
            line_count += 1
            # Print lines 150-165 to capture context around 156-160
            if 150 <= line_count <= 165:
                print(f"Line {line_count}: {line}")
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
EOF
```

Length of output: 217


🏁 Script executed:

```shell
#!/bin/bash
# Examine utils_cluster_prep.go around lines 53-54
fd -e go -name 'utils_cluster_prep.go' --exec cat {} \; | head -70 | tail -25
```

Length of output: 230


🏁 Script executed:

```shell
#!/bin/bash
# Search for runtime-related configurations and definitions
rg -n "runtime" -i --type go -A2 -B2 | head -100
```

Length of output: 7893



---

🏁 Script executed:

```shell
#!/bin/bash
# Find and read the notebook file directly
find . -name "mnist.ipynb" -path "*/tests/trainer/resources/*" -type f
```

Length of output: 37


🏁 Script executed:

```shell
#!/bin/bash
# Once found, extract notebook content and look for "universal" runtime reference
find . -name "mnist.ipynb" -path "*/tests/trainer/resources/*" -type f -exec python3 << 'EOF' {} \;
import json
import sys

try:
    with open(sys.argv[1], 'r') as f:
        notebook = json.load(f)
    
    # Search through all cells for "universal" string
    for cell_idx, cell in enumerate(notebook.get('cells', [])):
        source = cell.get('source', [])
        if isinstance(source, str):
            source_lines = source.split('\n')
        else:
            source_lines = source
        
        for line_idx, line in enumerate(source_lines):
            if 'universal' in line.lower() or 'torch_runtime' in line.lower():
                print(f"Cell {cell_idx}, Line {line_idx}: {line}")
except Exception as e:
    print(f"Error processing {sys.argv[1]}: {e}", file=sys.stderr)
EOF
```

Length of output: 218


🏁 Script executed:

```shell
#!/bin/bash
# Also search codebase for references to "universal" runtime
rg -n "universal" --type go
```

Length of output: 113



---

🏁 Script executed:

```shell
#!/bin/bash
# Search notebook file for "universal" using grep/rg
rg -n "universal" ./tests/trainer/resources/mnist.ipynb
```

Length of output: 108


🏁 Script executed:

```shell
#!/bin/bash
# Search for torch_runtime usage in notebook
rg -n "torch_runtime\|runtime" ./tests/trainer/resources/mnist.ipynb | head -30
```

Length of output: 0


🏁 Script executed:

```shell
#!/bin/bash
# Extract notebook with cat and parse JSON properly
cat ./tests/trainer/resources/mnist.ipynb | python3 << 'PYTHON_EOF'
import json
import sys

try:
    nb = json.load(sys.stdin)
    for idx, cell in enumerate(nb.get('cells', [])):
        source = cell.get('source', [])
        if isinstance(source, list):
            content = ''.join(source)
        else:
            content = source
        
        # Look for runtime-related content
        if 'universal' in content.lower() or 'torch_runtime' in content.lower():
            lines = content.split('\n')
            for line_num, line in enumerate(lines):
                if 'universal' in line.lower() or 'torch_runtime' in line.lower():
                    print(f"Cell {idx}: {line}")
except Exception as e:
    print(f"Error: {e}", file=sys.stderr)
PYTHON_EOF
```

Length of output: 48


Update notebook runtime selection to match validated test runtimes.

The notebook hardcodes a check for runtime.name == "universal" at line 158, but the test utility (utils_cluster_prep.go lines 53-54) validates only that the torch-cuda-241, torch-cuda-251, torch-rocm-241, and torch-rocm-251 runtimes are available. The "universal" runtime is marked as a future addition (see the TODO at utils_cluster_prep.go line 52). Since "universal" does not currently exist in the test environment, the notebook's torch_runtime variable will remain unset, causing failures in subsequent cells.

Align the notebook to use one of the currently available runtimes (e.g., "torch-cuda-241") or parameterize it with a fallback mechanism.

🤖 Prompt for AI Agents
In tests/trainer/resources/mnist.ipynb around lines 156 to 160, the notebook
checks for runtime.name == "universal" which doesn't exist in the test
environment and leaves torch_runtime unset; change the selection to match
validated runtimes (for example "torch-cuda-241") or implement a fallback loop
that picks the first available runtime from the allowed set
["torch-cuda-241","torch-cuda-251","torch-rocm-241","torch-rocm-251"]; update
the condition to test membership in that list and set torch_runtime accordingly
so downstream cells always have a valid runtime.
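
A sketch of the fallback selection the prompt describes, assuming the runtime names listed above; whether "universal" eventually becomes available is still open in this thread:

```python
from kubeflow.trainer import TrainerClient

client = TrainerClient()

# Preferred name first, then the runtimes validated by EnsureTrainerClusterReady.
allowed = ["universal", "torch-cuda-241", "torch-cuda-251", "torch-rocm-241", "torch-rocm-251"]

# Index the runtimes the cluster actually exposes by name.
available = {runtime.name: runtime for runtime in client.list_runtimes()}

# Pick the first allowed runtime that exists, so torch_runtime is never left unset.
torch_runtime = next((available[name] for name in allowed if name in available), None)
if torch_runtime is None:
    raise RuntimeError(f"none of the expected runtimes {allowed} is available")
```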

},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-03T13:19:56.525591Z",
"iopub.status.busy": "2025-09-03T13:19:56.524936Z",
"iopub.status.idle": "2025-09-03T13:19:56.721404Z",
"shell.execute_reply": "2025-09-03T13:19:56.720565Z",
"shell.execute_reply.started": "2025-09-03T13:19:56.525536Z"
}
},
"outputs": [],
"source": [
"job_name = client.train(\n",
" trainer=CustomTrainer(\n",
" func=train_fashion_mnist,\n",
" num_nodes=2,\n",
" resources_per_node={\n",
" \"cpu\": 2,\n",
" \"memory\": \"8Gi\",\n",
" },\n",
" packages_to_install=[\"torchvision\"],\n",
Contributor

I think it would be good to install a specific version, to make sure that a future upgrade doesn't break the test.

Contributor Author

Yes, I think it makes sense. Thanks.
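
A sketch of the pinned install; the exact version is a placeholder and should be chosen to match the PyTorch build shipped by the selected runtime:

```python
from kubeflow.trainer import CustomTrainer

# train_fashion_mnist is the function defined in the first notebook cell.
trainer = CustomTrainer(
    func=train_fashion_mnist,
    num_nodes=2,
    resources_per_node={"cpu": 2, "memory": "8Gi"},
    # Pin the version so a future torchvision release cannot silently break the test.
    packages_to_install=["torchvision==0.20.1"],  # placeholder version
)
```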

" ),\n",
" runtime=torch_runtime,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-03T13:20:01.378158Z",
"iopub.status.busy": "2025-09-03T13:20:01.377707Z",
"iopub.status.idle": "2025-09-03T13:20:12.713960Z",
"shell.execute_reply": "2025-09-03T13:20:12.713295Z",
"shell.execute_reply.started": "2025-09-03T13:20:01.378130Z"
}
},
"outputs": [],
"source": [
"# Wait for the running status.\n",
"client.wait_for_job_status(name=job_name, status={\"Running\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-03T13:20:24.045774Z",
"iopub.status.busy": "2025-09-03T13:20:24.045480Z",
"iopub.status.idle": "2025-09-03T13:20:24.772877Z",
"shell.execute_reply": "2025-09-03T13:20:24.772178Z",
"shell.execute_reply.started": "2025-09-03T13:20:24.045755Z"
}
},
"outputs": [],
"source": [
"for c in client.get_job(name=job_name).steps:\n",
" print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2025-09-03T13:20:26.729486Z",
"iopub.status.busy": "2025-09-03T13:20:26.728951Z",
"iopub.status.idle": "2025-09-03T13:20:29.596510Z",
"shell.execute_reply": "2025-09-03T13:20:29.594741Z",
"shell.execute_reply.started": "2025-09-03T13:20:26.729446Z"
}
},
"outputs": [],
"source": [
"for logline in client.get_job_logs(job_name, follow=True):\n",
" print(logline)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client.delete_job(job_name)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
86 changes: 86 additions & 0 deletions tests/trainer/sdk_tests/fashion_mnist_tests.go
@@ -0,0 +1,86 @@
/*
Copyright 2025.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sdk_tests

import (
"fmt"
"os"
"testing"

. "github.com/onsi/gomega"

corev1 "k8s.io/api/core/v1"

common "github.com/opendatahub-io/distributed-workloads/tests/common"
support "github.com/opendatahub-io/distributed-workloads/tests/common/support"
trainerutils "github.com/opendatahub-io/distributed-workloads/tests/trainer/utils"
)

const (
notebookName = "mnist.ipynb"
notebookPath = "resources/" + notebookName
)

// CPU Only - Distributed Training
func RunFashionMnistCpuDistributedTraining(t *testing.T) {
test := support.With(t)

// Create a new test namespace
namespace := test.NewTestNamespace()

// Ensure pre-requisites to run the test are met
trainerutils.EnsureTrainerClusterReady(t, test)

// Ensure Notebook SA and RBACs are set for this namespace
trainerutils.EnsureNotebookRBAC(t, test, namespace.Name)

// RBACs setup
userName := common.GetNotebookUserName(test)
userToken := common.GetNotebookUserToken(test)
support.CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin")

// Read notebook from directory
localPath := notebookPath
nb, err := os.ReadFile(localPath)
test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read notebook: %s", localPath))

// Create ConfigMap with notebook
cm := support.CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: nb})

// Build command
marker := "/opt/app-root/src/notebook_completion_marker"
shellCmd := trainerutils.BuildPapermillShellCmd(notebookName, marker, nil)
command := []string{"/bin/sh", "-c", shellCmd}

// Create Notebook CR (with default 10Gi PVC)
pvc := support.CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", support.AccessModes(corev1.ReadWriteOnce))
common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, 0, pvc, common.ContainerSizeSmall)

// Cleanup
defer func() {
common.DeleteNotebook(test, namespace)
test.Eventually(common.Notebooks(test, namespace), support.TestTimeoutLong).Should(HaveLen(0))
}()

// Wait for the Notebook Pod and get pod/container names
podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name)

Contributor

Would it make sense to add an assertion checking that the TrainJob is created and successfully finished?

Contributor Author

Yes, I think this is very important :)
Thanks
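
One notebook-side way to cover this is to block on a terminal TrainJob status before the cleanup cell; a sketch reusing the client from the earlier cells (the exact terminal status string accepted by wait_for_job_status is an assumption):

```python
# Block until the TrainJob reaches a terminal state before deleting it.
client.wait_for_job_status(name=job_name, status={"Complete"})  # "Failed" could be added to fail fast

# Print the final per-step status as an extra sanity check.
for step in client.get_job(name=job_name).steps:
    print(f"Step: {step.name}, Status: {step.status}")
```

The Go test could then assert on the completion marker only after the notebook has reported that the job finished.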

// Poll marker file to check if the notebook execution completed successfully
if err := trainerutils.PollNotebookCompletionMarker(test, namespace.Name, podName, containerName, marker, support.TestTimeoutDouble); err != nil {
test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE")
}
}