Skip to content

Commit 3a9d3fd

Browse files
authored
Added conditional GPU accelerated ODM (#167)
* Added kmeans from cuml Signed-off-by: romit <[email protected]> * Fixed ci cd Signed-off-by: romit <[email protected]> --------- Signed-off-by: romit <[email protected]>
1 parent c9adae5 commit 3a9d3fd

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

plugins/online-data-mixing/artifacts/custom_loop_usage.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@
1919
from fms_acceleration_odm import OnlineMixingDataset
2020
from fms_acceleration_odm.odm.reward import Reward
2121

22-
model_name = "ibm-granite/granite-4.0-h-1b"
22+
model_name = "ibm-granite/granite-4.0-350m"
2323
output_dir = "./odm_custom_use"
2424
max_steps = 125
2525
batch_size = 4
2626
log_file = os.path.join(output_dir, "loss.jsonl")
2727

2828
# odm related
2929
step_idx = 0
30-
update_interval = 1 # every step
30+
update_interval = 10 # every 10 steps
3131

3232
# model
33-
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16)
33+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
3434

3535
# tokenizer
3636
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -102,7 +102,7 @@ def collate_fn(batch, tokenizer):
102102
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=None)
103103

104104
# distributed setup
105-
dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True)
105+
dataloader_config = DataLoaderConfiguration(dispatch_batches=False)
106106
accelerator = Accelerator(dataloader_config=dataloader_config)
107107
model, dataloader = accelerator.prepare(model, dataloader)
108108

@@ -141,7 +141,7 @@ class State:
141141
if step_idx % update_interval == 0:
142142
with torch.no_grad():
143143
model.eval()
144-
dataloader.dataset.update_sampling_weights(model, accelerator, state)
144+
dataset.update_sampling_weights(model, accelerator, state)
145145
model.train()
146146
if step_idx > max_steps:
147147
break

plugins/online-data-mixing/pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@ license = {text = "Apache-2.0"}
1515
readme = "README.md"
1616
requires-python = "~=3.11"
1717
keywords = ['fms-hf-tuning', 'acceleration', 'online-data-mixing']
18-
classifiers=[
18+
classifiers = [
1919
"License :: OSI Approved :: Apache Software License",
2020
"Development Status :: 4 - Beta",
2121
"Programming Language :: Python :: 3",
2222
"Programming Language :: Python :: 3.11",
2323
]
2424

2525
dependencies = [
26+
"torch==2.8.0",
27+
"torchvision==0.23.0",
28+
"torchaudio==2.8.0",
2629
"scikit-learn",
2730
"datasets==4.*",
2831
"torchdata==0.11.0",
@@ -43,4 +46,4 @@ include = [
4346
"src",
4447
"pyproject.toml",
4548
"README.md",
46-
]
49+
]

plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
# Third Party
2828
from datasets import Dataset, DatasetDict
2929
from sentence_transformers import SentenceTransformer
30-
from sklearn.cluster import KMeans
3130
import numpy as np
3231
import torch
3332

@@ -175,6 +174,14 @@ def _cluster_embeddings(
175174
"Unsupported clustering algorithm '%s'. Only 'kmeans' is currently supported."
176175
% self.config.cluster_algo
177176
)
177+
178+
try:
179+
from cuml import KMeans # pylint: disable=import-outside-toplevel
180+
print("Using GPU accelerated Kmeans")
181+
except ImportError:
182+
print("GPU accelerated KMeans is not avaialble. Falling back to CPU based KMeans")
183+
from sklearn.cluster import KMeans # pylint: disable=import-outside-toplevel
184+
178185
kwargs = {"n_init": 10}
179186
kwargs.update(self.config.cluster_kwargs)
180187
model = KMeans(n_clusters=num_categories, **kwargs)

0 commit comments

Comments
 (0)