2 changes: 1 addition & 1 deletion plugins/accelerated-moe/.pylintrc
@@ -281,7 +281,7 @@ ignored-parents=
max-args=5

# Maximum number of attributes for a class (see R0902).
max-attributes=7
max-attributes=8

# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
@@ -77,6 +77,7 @@ def augmentation(
modifiable_args: Tuple[LoraConfig],
):
rank, world_size = 0, 1
(peft_config,) = modifiable_args
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
# we do not need to use the fallback as this is wrapped in an `is_initialized` block
@@ -97,6 +98,7 @@ def augmentation(
ep_degree=self._ep_degree,
disable_distributed=self._disable_distributed,
mixed_precision=False, # Currently this is hardcoded to OFF
lora_config=peft_config,
)
return model, modifiable_args
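A hedged sketch of the contract the unpacking above relies on: `modifiable_args` carries exactly one `LoraConfig`, which is then forwarded to `prepare_scattermoe` as `lora_config`. The config values below are illustrative, not taken from the PR; the only constraints visible in this plugin are `bias="none"` and a standard `init_lora_weights`.

```python
# Sketch (not the framework's actual wiring): a LoraConfig handed through
# `modifiable_args` as a one-element tuple, matching the unpacking above.
from peft import LoraConfig

peft_config = LoraConfig(
    r=8,                     # illustrative rank
    lora_alpha=16,           # illustrative scaling
    bias="none",             # the plugin asserts bias == "none"
    init_lora_weights=True,  # the plugin asserts a standard init
)

modifiable_args = (peft_config,)
(unpacked,) = modifiable_args  # fails loudly if the tuple shape ever changes
assert unpacked is peft_config
```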

@@ -53,6 +53,8 @@
KEY_MODEL = "model"
KEY_OPTIMIZER = "optimizer"

ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"

# Below are rewrites of the HF FSDP model saving functions, adapted to handle
# parameters that are now a mixture of regular tensors and DTensors.
# - the original functions are found in accelerate.utils.fsdp_utils.py
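As a hedged aside on the comment above: a state dict that mixes plain tensors and DTensors can be normalised to ordinary tensors before any safetensors-style saving, e.g. with `DTensor.full_tensor()`. This sketch is not part of the PR.

```python
# Sketch: gather a mixed state dict into ordinary tensors
# (DTensor.full_tensor() collects the shards on each rank).
from torch.distributed._tensor import DTensor


def to_full_tensors(state_dict):
    out = {}
    for name, param in state_dict.items():
        out[name] = param.full_tensor() if isinstance(param, DTensor) else param
    return out
```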
@@ -110,16 +112,30 @@ def save_fsdp_optimizer(
# get the state dicts for model and optimizer
(model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)

# filter the lora weights (lora_A / lora_B) into their own state dict
lora_state_dict = {
k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
}

# - save model
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: model_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)
logger.info(f"Model saved to {ckpt_model}")
if lora_state_dict:
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving lora model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: lora_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)
else:
ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving ft model to {ckpt_model}")
dcp.save(
state_dict={KEY_MODEL: model_state_dict},
storage_writer=dcp.FileSystemWriter(ckpt_model),
planner=DefaultSavePlanner(),
)

# - save optimizer
ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
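The two model-saving branches above differ only in which state dict is written and in the log label; an equivalent, branch-free sketch (reusing the function's existing names, and assuming an empty `lora_state_dict` means full fine-tuning) could read:

```python
# Sketch of a branch-free version of the model-saving step above.
state_to_save = lora_state_dict if lora_state_dict else model_state_dict
label = "lora" if lora_state_dict else "ft"

ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
os.makedirs(ckpt_model, exist_ok=True)
logger.info(f"Saving {label} model to {ckpt_model}")
dcp.save(
    state_dict={KEY_MODEL: state_to_save},
    storage_writer=dcp.FileSystemWriter(ckpt_model),
    planner=DefaultSavePlanner(),
)
```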
@@ -467,30 +483,54 @@ def save_sharded_safetensors(
save_directory: str,
metadata: Dict,
max_shard_size: Union[int, str] = "5GB",
lora: bool = False,
):
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(
os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
) as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
if not lora:
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)

filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the index
with open(
os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
) as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {
tensor: input_state_dict[tensor].contiguous() for tensor in tensors
}
save_file(
shard, os.path.join(save_directory, shard_file), metadata=metadata
)
else:
filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(
".bin", "{suffix}.bin"
).replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
input_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size,
)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {
tensor: input_state_dict[tensor].contiguous() for tensor in tensors
}
save_file(
shard, os.path.join(save_directory, shard_file), metadata=metadata
)
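The two branches of `save_sharded_safetensors` differ only in the base filename (`model*.safetensors` vs `adapter_model*.safetensors`) and in whether the weight-map index is written. A hedged, behaviour-preserving sketch that collapses the duplication, reusing the module's existing imports and variables:

```python
# Sketch: single code path; only the base filename and the index file differ.
weights_name = ADAPTER_SAFE_WEIGHTS_NAME if lora else SAFE_WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
    ".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
    input_state_dict,
    filename_pattern=filename_pattern,
    max_shard_size=max_shard_size,
)
if not lora:
    # the weight-map index is only written for the full model weights
    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
    with open(
        os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
    ) as f:
        f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
for shard_file, tensors in state_dict_split.filename_to_tensors.items():
    shard = {t: input_state_dict[t].contiguous() for t in tensors}
    save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
```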


# --------------------------- SCRIPT -------------------------
@@ -540,14 +580,32 @@ def recover_safetensors_from_dcp(
# get the state_dict
state_dict = loader(checkpoint_dir)

# strip the additional name prefixes introduced by lora tuning, and set a
# switch, based on the state dict contents, for use when saving below
new_state_dict = {}
lora = False
for name, param in state_dict.items():
# if lora weight, set lora switch to true
if "lora_A" in name or "lora_B" in name:
lora = True
# if the peft lora naming convention is used, convert back to the base model naming
if "base_model.model." in name:
name = name.replace("base_model.model.", "", 1)
if "default." in name:
name = name.replace("default.", "", 1)
new_state_dict[name] = param

# recover the original state dict
state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
state_dict = recover_original_state_dict_from_checkpoint(
new_state_dict, _name_or_path
)

# save it as a safetensors file
save_sharded_safetensors(
{k: v.contiguous() for k, v in state_dict.items()},
output_dir,
metadata={"format": "pt"},
lora=lora,
)
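For illustration, the prefix stripping above maps peft-style adapter names back onto the base-model naming. The example key below is hypothetical but follows the usual `base_model.model.<module>.lora_A.default.weight` convention:

```python
# Illustration (not from the PR) of the renaming applied in the loop above.
def strip_lora_prefixes(name: str) -> str:
    if "base_model.model." in name:
        name = name.replace("base_model.model.", "", 1)
    if "default." in name:
        name = name.replace("default.", "", 1)
    return name


print(strip_lora_prefixes(
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
))
# -> model.layers.0.self_attn.q_proj.lora_A.weight
```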


@@ -17,7 +17,6 @@

# Third Party
from peft import LoraConfig
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
from torch.distributed._tensor import DTensor

# pylint: disable=import-error
@@ -237,10 +236,6 @@ def __init__(
assert (
lora_config.bias == "none"
), "ScatterMoE currently unable to handle bias in the lora adapters"
assert (
lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND
or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules
), "ScatterMoe currently only handles lora adapters on all linears."

assert lora_config.init_lora_weights in {
True,
@@ -286,7 +281,6 @@ def __init__(
grouped_out=True,
dtype=dtype,
device=device,
lora_config=lora_config,
)
self.w2 = ScatteredExperts(
in_features=self.intermediate_size,
@@ -296,7 +290,6 @@
grouped_in=True,
dtype=dtype,
device=device,
lora_config=lora_config,
)
if mlp_arch == SCATTERMOE_SPEC_HAS_GATE:
self.w3 = ScatteredExperts(
@@ -307,7 +300,6 @@
grouped_out=True,
dtype=dtype,
device=device,
lora_config=lora_config,
)

# referenced from dolomite-engine
@@ -92,7 +92,8 @@ def _hook(grad):

# install gradient scaling hook
if KEY_SCATTERMOE_ROUTER not in weight_name:
param.register_hook(_hook)
if param.requires_grad:
param.register_hook(_hook)

# register the sharded parameter onto the megablocks.dmoe
mod.register_parameter(name, param)
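The `requires_grad` guard matters because PyTorch refuses to register a hook on a tensor that does not require grad, which is exactly the state of frozen base weights under lora tuning. A minimal reproduction:

```python
# Minimal sketch: registering a hook on a frozen parameter raises.
import torch

frozen = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
try:
    frozen.register_hook(lambda grad: grad)
except RuntimeError as err:
    print(err)  # "cannot register a hook on a tensor that doesn't require gradient"
```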