7 changes: 6 additions & 1 deletion src/mini_trainer/api_train.py
@@ -180,7 +180,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:

if train_args.save_dtype:
command.append(f"--save-dtype={train_args.save_dtype}")


command.append(f"--device={train_args.device}")
if train_args.torch_compile:
command.append("--torch-compile")
command.append(f"--num-chunks={train_args.num_chunks}")

logger.info("Running training command as subprocess: %s", " ".join(command))

# Run the training process
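The hunk above forwards the new device, torch-compile, and num-chunks options from TrainingArgs to the training subprocess. A minimal, self-contained sketch of the flag-forwarding pattern (TrainingArgsSketch is a stand-in for the real TrainingArgs dataclass):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingArgsSketch:
    device: str = "cuda"
    torch_compile: bool = False
    num_chunks: int = 1
    save_dtype: Optional[str] = "bfloat16"

def build_flags(train_args: TrainingArgsSketch) -> list:
    command = []
    if train_args.save_dtype:
        command.append(f"--save-dtype={train_args.save_dtype}")
    # --device and --num-chunks are always forwarded; --torch-compile is opt-in
    command.append(f"--device={train_args.device}")
    if train_args.torch_compile:
        command.append("--torch-compile")
    command.append(f"--num-chunks={train_args.num_chunks}")
    return command

print(" ".join(build_flags(TrainingArgsSketch(torch_compile=True))))
# --save-dtype=bfloat16 --device=cuda --torch-compile --num-chunks=1
```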
223 changes: 125 additions & 98 deletions src/mini_trainer/osft_utils.py
@@ -5,12 +5,12 @@
import torch.distributed as dist
import typing as t
from typing import Protocol
import multiprocessing

from tqdm import tqdm

from mini_trainer.utils import log_rank_0, check_distributed_is_synchronized
from mini_trainer.gpt_oss_utils import is_gpt_oss_model
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

import os

@@ -291,6 +291,11 @@ def cast_to_osft_model(model: torch.nn.Module) -> OSFTModel:
raise TypeError(f"Model {type(model)} does not implement OSFT interface")
return model # type: ignore


def create_svd_dict_star(args_tuple):
    # Unpack a single tuple of arguments for multiprocessing.Pool.imap,
    # which passes exactly one item to each worker call.
    return create_svd_dict(*args_tuple)


def create_svd_dict(
weight: torch.Tensor,
top_k: int,
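create_svd_dict (signature truncated above; the full version also takes decompose_existing, upcast_dtype, and output_dtype) produces the high/low SVD components the rest of the file consumes. A simplified sketch of the split, assuming top_k counts the frozen high-rank directions and V factors are stored column-wise:

```python
import torch

def split_svd_sketch(weight: torch.Tensor, top_k: int) -> dict:
    # Full SVD of the weight, upcast to float32 for numerical stability
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    return {
        # Assumption: the top_k directions form the frozen high-rank subspace
        "U_high": U[:, :top_k], "S_high": S[:top_k], "V_high": Vh[:top_k, :].T,
        # The remainder is the trainable low-rank subspace
        "U_low": U[:, top_k:], "S_low": S[top_k:], "V_low": Vh[top_k:, :].T,
        "rank_high": top_k,
    }

svd = split_svd_sketch(torch.randn(8, 8), top_k=6)
print(svd["U_high"].shape, svd["U_low"].shape)  # torch.Size([8, 6]) torch.Size([8, 2])
```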
@@ -503,7 +508,7 @@ def partition_svd_computation(target_params, world_size):
return assignments


def _broadcast_tensor_device_aware(tensor, src_rank):
def _broadcast_tensor_device_aware(tensor, src_rank, device):
"""
Broadcasts a tensor that might be on CPU by temporarily moving to GPU if needed.

@@ -514,9 +519,7 @@ def _broadcast_tensor_device_aware(tensor, src_rank):
# If tensor is on CPU, we need to move to GPU for NCCL broadcasting
if tensor.device.type == 'cpu':
current_rank = dist.get_rank()
local_rank = int(os.environ.get('LOCAL_RANK', 0))
gpu_device = torch.device('cuda', local_rank)
tensor_gpu = tensor.to(gpu_device)
tensor_gpu = tensor.to(device)

# Broadcast the GPU tensor
dist.broadcast(tensor_gpu, src=src_rank)
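Passing device explicitly replaces the previous LOCAL_RANK lookup inside the helper. For reference, a sketch of the complete round-trip the function performs, assuming an NCCL backend that cannot broadcast CPU tensors:

```python
import torch
import torch.distributed as dist

def broadcast_device_aware_sketch(tensor, src_rank, device):
    # NCCL collectives require accelerator tensors, so a CPU tensor is staged
    # on the device, broadcast there, and the result copied back in place.
    if tensor.device.type == "cpu":
        tensor_gpu = tensor.to(device)            # stage on the accelerator
        dist.broadcast(tensor_gpu, src=src_rank)  # collective on the GPU copy
        tensor.copy_(tensor_gpu.cpu())            # write the result back
    else:
        dist.broadcast(tensor, src=src_rank)
```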
@@ -547,6 +550,7 @@ def broadcast_svd_results(model, assignments, world_size):
param_to_rank[name] = rank

# Broadcast each parameter from its computing rank
device = next(model.parameters()).device
for name in param_to_rank:
src_rank = param_to_rank[name]
# Locate the owning module for this parameter
@@ -558,7 +562,7 @@
for tensor_name in ("osft_U_high", "osft_S_high", "osft_V_high"):
if hasattr(owner_mod, tensor_name):
tensor = getattr(owner_mod, tensor_name)
_broadcast_tensor_device_aware(tensor, src_rank)
_broadcast_tensor_device_aware(tensor, src_rank, device)
else:
raise AttributeError(f"OSFT broadcast: {tensor_name} not found on module for {name}")

@@ -568,9 +572,9 @@
# Broadcast trainable low-rank components (on module.osft_params)
if hasattr(owner_mod, "osft_params"):
svd_module = owner_mod.osft_params
_broadcast_tensor_device_aware(svd_module.U_low, src_rank)
_broadcast_tensor_device_aware(svd_module.S_low, src_rank)
_broadcast_tensor_device_aware(svd_module.V_low, src_rank)
_broadcast_tensor_device_aware(svd_module.U_low, src_rank, device)
_broadcast_tensor_device_aware(svd_module.S_low, src_rank, device)
_broadcast_tensor_device_aware(svd_module.V_low, src_rank, device)
else:
raise AttributeError(f"OSFT broadcast: osft_params not found on module for {name}")
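For clarity, the inversion of assignments into param_to_rank, with the mapping shape ({rank: [param_name, ...]}) taken from the loop at the top of broadcast_svd_results:

```python
# Each rank computed SVDs for its assigned parameters; the inverse map tells
# every rank which source rank to receive each broadcast from.
assignments = {
    0: ["model.layers.0.self_attn.q_proj.weight"],
    1: ["model.layers.0.mlp.down_proj.weight"],
}
param_to_rank = {name: rank for rank, names in assignments.items() for name in names}
print(param_to_rank["model.layers.0.mlp.down_proj.weight"])  # 1
```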

@@ -857,7 +861,7 @@ def _initialize_osft_with_distribution(model):
Returns:
The initialized model
"""

if not dist.is_initialized() or dist.get_world_size() == 1:
# Simple cases: non-distributed or single process
model.reinitialize_osft(decompose_existing_weights=True)
@@ -965,6 +969,7 @@ def from_pretrained(
if is_gpt_oss:
base_kwargs = _filter_osft_parameters(filtered_kwargs, OSFT_GPT_OSS_FILTERED_PARAMS)
# For GPT-OSS, we need to use the specific model class
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
actual_osft_cls = create_osft_model_class(GptOssForCausalLM)
else:
base_kwargs = filtered_kwargs.copy()
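Moving the GptOssForCausalLM import inside from_pretrained defers loading the GPT-OSS module until it is actually needed. The selection pattern, with illustrative stubs standing in for the real transformers classes:

```python
# Stand-ins only: the real create_osft_model_class wraps a transformers base
# class; these stubs just demonstrate the conditional class-factory pattern.
def create_osft_model_class(base_cls):
    return type(f"OSFT{base_cls.__name__}", (base_cls,), {})

class GenericCausalLM: ...
class GptOssForCausalLM: ...

def pick_osft_class(is_gpt_oss: bool):
    # GPT-OSS needs its concrete model class; everything else stays generic
    base_cls = GptOssForCausalLM if is_gpt_oss else GenericCausalLM
    return create_osft_model_class(base_cls)

print(pick_osft_class(True).__name__)   # OSFTGptOssForCausalLM
print(pick_osft_class(False).__name__)  # OSFTGenericCausalLM
```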
@@ -1110,7 +1115,7 @@ def forward(x):
mod.forward = make_forward(mod, bias)
param.requires_grad = False
mod._parameters.pop(attr, None)

# Step 4: Each rank computes SVD for its assigned parameters
assigned_params = assignments[global_rank]
if assigned_params:
@@ -1122,7 +1127,8 @@
# It's possible for processes to be desynchronized by this point due to the
# uneven split of work when processing model layers in parallel.
# So here we ensure that everything is synchronized before proceeding.
check_distributed_is_synchronized()
device = next(self.parameters()).device
check_distributed_is_synchronized(device)
broadcast_svd_results(self, assignments, world_size)
torch.distributed.barrier()
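check_distributed_is_synchronized now receives the model's device rather than deriving one itself; its body lives in mini_trainer.utils and is not shown in this diff. A plausible sketch of what a device-aware check can look like:

```python
import torch
import torch.distributed as dist

def check_synchronized_sketch(device: torch.device) -> None:
    # A tiny collective on the compute device: if any rank has diverged or
    # exited early, this all_reduce hangs or errors instead of letting the
    # broadcasts that follow silently corrupt state.
    if not dist.is_initialized():
        return
    probe = torch.ones(1, device=device)
    dist.all_reduce(probe)  # every rank must reach this point
```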

@@ -1142,6 +1148,48 @@ def _get_module_by_name(self, name):
return None, None
return mod, attr

def _save_svd_dict_to_module(self, name, param, svd_dict):
safe_name = name.replace(".", "_")
self.name_mapping[name] = safe_name

# Attach OSFT components to the owning module so only block-local params materialize
mod, attr = self._get_module_by_name(name)
# High-rank frozen components
mod.register_parameter("osft_U_high", nn.Parameter(svd_dict["U_high"], requires_grad=False))
mod.register_parameter("osft_S_high", nn.Parameter(svd_dict["S_high"], requires_grad=False))
mod.register_parameter("osft_V_high", nn.Parameter(svd_dict["V_high"], requires_grad=False))
# Trainable low-rank components
module_svd = nn.Module()
module_svd.U_low = svd_dict["U_low"]
module_svd.S_low = svd_dict["S_low"]
module_svd.V_low = svd_dict["V_low"]
module_svd.rank_high = svd_dict["rank_high"]
module_svd.safe_name = safe_name
mod.add_module("osft_params", module_svd)

bias = mod.bias if hasattr(mod, "bias") else None

# Override linear projection to use module-local OSFT params
def make_forward(owner_mod, bias):
def forward(x):
svd_dict = {
"U_high": owner_mod.osft_U_high,
"S_high": owner_mod.osft_S_high,
"V_high": owner_mod.osft_V_high,
"U_low": owner_mod.osft_params.U_low,
"S_low": owner_mod.osft_params.S_low,
"V_low": owner_mod.osft_params.V_low,
"rank_high": owner_mod.osft_params.rank_high,
}
return self._factorized_linear(x, svd_dict, bias)
return forward

mod.forward = make_forward(mod, bias)
param.requires_grad = False
# Remove original parameter so it doesn't get updated
mod._parameters.pop(attr, None)
torch.cuda.empty_cache()
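The overridden forward defers to self._factorized_linear, defined elsewhere in this file. A minimal sketch of the reconstruction it presumably performs, under the same storage convention assumed earlier (V factors column-wise, nn.Linear's y = xW^T + b layout):

```python
import torch

def factorized_linear_sketch(x, svd_dict, bias=None):
    # Assumed reconstruction: W ≈ U_high diag(S_high) V_high^T
    #                             + U_low diag(S_low) V_low^T
    w = (svd_dict["U_high"] * svd_dict["S_high"]) @ svd_dict["V_high"].T
    w = w + (svd_dict["U_low"] * svd_dict["S_low"]) @ svd_dict["V_low"].T
    out = x @ w.T  # nn.Linear convention: y = x W^T + b
    return out if bias is None else out + bias

# Smoke test: random factors of a 4x4 weight split at rank 3
U, S, Vh = torch.linalg.svd(torch.randn(4, 4), full_matrices=False)
svd = {"U_high": U[:, :3], "S_high": S[:3], "V_high": Vh[:3, :].T,
       "U_low": U[:, 3:], "S_low": S[3:], "V_low": Vh[3:, :].T}
print(factorized_linear_sketch(torch.randn(2, 4), svd).shape)  # torch.Size([2, 4])
```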

def _initialize_osft_parameters(
self, decompose_existing_weights: bool, assigned_params=None
):
@@ -1187,97 +1235,76 @@ def _initialize_osft_parameters(
log_rank_0("\033[33m!!!! Initializing OSFT Params!!!!\033[0m")

# Set up target device for memory-efficient operations
local_rank = int(os.getenv("LOCAL_RANK", 0))
target_device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
target_device = next(self.parameters()).device

if self.osft_memory_efficient_init:
log_rank_0(f"🧠 Using ultra-memory-efficient OSFT initialization for device {target_device}")

for name, param in named_params:
# Apply SVD only to 2D matrices in the target config (e.g., q_proj, down_proj, etc.)
if is_osft_param(name, param, self.osft_config):
top_k = self.osft_config[name]

if self.osft_memory_efficient_init:
# Memory monitoring before processing
if torch.cuda.is_available():
mem_before = torch.cuda.memory_allocated(target_device) / 1e9
log_rank_0(f"🔄 Processing {name} with incremental GPU usage (top_k={top_k}) - GPU mem: {mem_before:.2f}GB")
if target_device.type != "hpu":
for name, param in named_params:
# Apply SVD only to 2D matrices in the target config (e.g., q_proj, down_proj, etc.)
if is_osft_param(name, param, self.osft_config):
top_k = self.osft_config[name]

# Memory-efficient processing: move parameter to GPU temporarily for SVD
param_gpu = param.data.to(target_device)

# Perform SVD on GPU
svd_dict = create_svd_dict(
param_gpu,
top_k=top_k,
decompose_existing=decompose_existing_weights,
upcast_dtype=self.upcast_dtype,
output_dtype=self.output_dtype,
)

# Move SVD components to target device and clear GPU cache
for key in svd_dict:
if isinstance(svd_dict[key], torch.Tensor):
svd_dict[key] = svd_dict[key].to(target_device)

# Clear the temporary GPU tensor
del param_gpu
torch.cuda.empty_cache()

# Memory monitoring after processing
if torch.cuda.is_available():
mem_after = torch.cuda.memory_allocated(target_device) / 1e9
log_rank_0(f"✅ Completed {name} - GPU mem: {mem_after:.2f}GB (freed: {mem_before-mem_after:.2f}GB)")
else:
# Standard processing
svd_dict = create_svd_dict(
param.data,
top_k=top_k,
decompose_existing=decompose_existing_weights,
upcast_dtype=self.upcast_dtype,
output_dtype=self.output_dtype,
)
safe_name = name.replace(".", "_")
self.name_mapping[name] = safe_name

# Attach OSFT components to the owning module so only block-local params materialize
mod, attr = self._get_module_by_name(name)
# High-rank frozen components
mod.register_parameter("osft_U_high", nn.Parameter(svd_dict["U_high"], requires_grad=False))
mod.register_parameter("osft_S_high", nn.Parameter(svd_dict["S_high"], requires_grad=False))
mod.register_parameter("osft_V_high", nn.Parameter(svd_dict["V_high"], requires_grad=False))
# Trainable low-rank components
module_svd = nn.Module()
module_svd.U_low = svd_dict["U_low"]
module_svd.S_low = svd_dict["S_low"]
module_svd.V_low = svd_dict["V_low"]
module_svd.rank_high = svd_dict["rank_high"]
module_svd.safe_name = safe_name
mod.add_module("osft_params", module_svd)

bias = mod.bias if hasattr(mod, "bias") else None

# Override linear projection to use module-local OSFT params
def make_forward(owner_mod, bias):
def forward(x):
svd_dict = {
"U_high": owner_mod.osft_U_high,
"S_high": owner_mod.osft_S_high,
"V_high": owner_mod.osft_V_high,
"U_low": owner_mod.osft_params.U_low,
"S_low": owner_mod.osft_params.S_low,
"V_low": owner_mod.osft_params.V_low,
"rank_high": owner_mod.osft_params.rank_high,
}
return self._factorized_linear(x, svd_dict, bias)
return forward

mod.forward = make_forward(mod, bias)
param.requires_grad = False
# Remove original parameter so it doesn't get updated
mod._parameters.pop(attr, None)
torch.cuda.empty_cache()
if self.osft_memory_efficient_init:
# Memory monitoring before processing
if torch.cuda.is_available():
mem_before = torch.cuda.memory_allocated(target_device) / 1e9
log_rank_0(f"🔄 Processing {name} with incremental GPU usage (top_k={top_k}) - GPU mem: {mem_before:.2f}GB")

# Memory-efficient processing: move parameter to GPU temporarily for SVD
param_gpu = param.data.to(target_device)

# Perform SVD on GPU
svd_dict = create_svd_dict(
param_gpu,
top_k=top_k,
decompose_existing=decompose_existing_weights,
upcast_dtype=self.upcast_dtype,
output_dtype=self.output_dtype,
)

# Move SVD components to target device and clear GPU cache
for key in svd_dict:
if isinstance(svd_dict[key], torch.Tensor):
svd_dict[key] = svd_dict[key].to(target_device)

# Clear the temporary GPU tensor
del param_gpu
torch.cuda.empty_cache()

# Memory monitoring after processing
if torch.cuda.is_available():
mem_after = torch.cuda.memory_allocated(target_device) / 1e9
log_rank_0(f"✅ Completed {name} - GPU mem: {mem_after:.2f}GB (freed: {mem_before-mem_after:.2f}GB)")
else:
# Standard processing
svd_dict = create_svd_dict(
param.data,
top_k=top_k,
decompose_existing=decompose_existing_weights,
upcast_dtype=self.upcast_dtype,
output_dtype=self.output_dtype,
)
self._save_svd_dict_to_module(name, param, svd_dict)

else:
osft_params_to_process = []
for name, param in named_params:
if is_osft_param(name, param, self.osft_config):
osft_params_to_process.append((name, param, param.data, self.osft_config[name], decompose_existing_weights, self.upcast_dtype, self.output_dtype))

svd_args = [(param_data.to('cpu'), top_k, decompose_existing, upcast_dtype, output_dtype)
for _, _, param_data, top_k, decompose_existing, upcast_dtype, output_dtype in osft_params_to_process]

os.environ['PT_HPU_AUTOLOAD'] = '0'
num_cpu_per_rank = 8  # TODO: make this configurable via a command-line parameter
mp_context = multiprocessing.get_context('spawn')
with mp_context.Pool(num_cpu_per_rank) as pool:
res = pool.imap(create_svd_dict_star, svd_args)
for i, svd_dict in enumerate(res):
name, param, _, _, _, _, _ = osft_params_to_process[i]
self._save_svd_dict_to_module(name, param, svd_dict)

# Barrier for synchronization in distributed setting
if dist.is_initialized():
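The non-CUDA branch above fans the SVD work out to CPU worker processes: PT_HPU_AUTOLOAD is disabled first, a spawn context keeps workers free of inherited accelerator state, and create_svd_dict_star unpacks each argument tuple for Pool.imap. A self-contained sketch of the same fan-out (svd_star_sketch is illustrative, not the PR's function):

```python
import multiprocessing
import torch

def svd_star_sketch(args_tuple):
    # Pool.imap passes a single item per call, so arguments travel as a tuple
    weight, top_k = args_tuple
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    return U[:, :top_k], S[:top_k], Vh[:top_k, :]

if __name__ == "__main__":
    jobs = [(torch.randn(32, 32), 8) for _ in range(4)]
    # 'spawn' gives each worker a clean interpreter with no inherited
    # accelerator handles, matching the PT_HPU_AUTOLOAD=0 guard above
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(2) as pool:
        for U, S, Vh in pool.imap(svd_star_sketch, jobs):
            print(U.shape, S.shape, Vh.shape)
```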