2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/feature-request.yml
@@ -27,4 +27,4 @@ body:
description: |
How can you contribute to this feature? For example, could you help by submitting a PR?
validations:
required: true
required: true
2 changes: 1 addition & 1 deletion README.md
@@ -50,7 +50,7 @@ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_pro
## Training Recipes

Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus ~~in streaming mode~~. (Do not use streaming mode if you need to be able to resume training.)

> [!WARNING]
> If the dataset has not been downloaded beforehand, streaming mode will attempt to fetch it from the remote server and download it on the fly, which can be highly unstable during training due to network issues.
> For stable training, ensure the dataset is downloaded locally first (see [**Dataset Preparation**](#dataset-preparation)); streaming mode should only be used for quick tests on a new corpus.
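
As a rough illustration of the non-streaming path, here is a minimal sketch that materializes the corpus on local disk before launching training (the `train` split and the `num_proc` value are assumptions; adjust them to your setup):

```py
# Minimal sketch: download and prepare the dataset locally up front,
# instead of streaming it over the network during the training run.
from datasets import load_dataset

dataset = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-100BT",
    split="train",   # assumed split name
    num_proc=64,     # assumed number of download/preparation workers
)
print(dataset)       # confirm the dataset is fully materialized before training
```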
2 changes: 1 addition & 1 deletion configs/delta_net_1B.json
@@ -26,4 +26,4 @@
"use_gate": false,
"use_output_norm": true,
"use_short_conv": true
}
}
2 changes: 1 addition & 1 deletion configs/delta_net_340M.json
@@ -23,4 +23,4 @@
"use_gate": false,
"use_output_norm": true,
"use_short_conv": true
}
}
2 changes: 1 addition & 1 deletion configs/gated_deltanet_1B.json
@@ -19,4 +19,4 @@
"use_cache": true,
"use_gate": true,
"use_short_conv": true
}
}
2 changes: 1 addition & 1 deletion configs/gated_deltanet_340M.json
@@ -19,4 +19,4 @@
"use_cache": true,
"use_gate": true,
"use_short_conv": true
}
}
2 changes: 1 addition & 1 deletion configs/gla_340M.json
@@ -21,4 +21,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/gla_7B.json
@@ -22,4 +22,4 @@
"use_gv": false,
"use_output_gate": true,
"use_short_conv": false
}
}
2 changes: 1 addition & 1 deletion configs/gsa_340M.json
@@ -26,4 +26,4 @@
"use_output_gate": true,
"use_rope": false,
"use_short_conv": false
}
}
2 changes: 1 addition & 1 deletion configs/hgrn2_340M.json
@@ -17,4 +17,4 @@
"tie_word_embeddings": false,
"use_cache": true,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/mamba2_1B.json
@@ -29,4 +29,4 @@
"use_cache": true,
"use_conv_bias": true,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/mamba2_340M.json
@@ -29,4 +29,4 @@
"use_cache": true,
"use_conv_bias": true,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/mamba_1B.json
@@ -27,4 +27,4 @@
"use_cache": true,
"use_conv_bias": true,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/mamba_340M.json
@@ -27,4 +27,4 @@
"use_cache": true,
"use_conv_bias": true,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/samba_1B.json
@@ -49,4 +49,4 @@
"use_cache": true,
"use_conv_bias": true,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/transformer_1B.json
@@ -19,4 +19,4 @@
"pad_token_id": 2,
"rope_theta": 10000.0,
"tie_word_embeddings": false
}
}
2 changes: 1 addition & 1 deletion configs/transformer_340M.json
@@ -15,4 +15,4 @@
"tie_word_embeddings": false,
"use_cache": true,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion configs/transformer_7B.json
@@ -18,4 +18,4 @@
"tie_word_embeddings": false,
"use_cache": true,
"window_size": null
}
}
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -29,14 +29,18 @@ dependencies = [
'tensorboard',
]

[tool.setuptools.dynamic]
version = {attr = "flame.__version__"}

[project.optional-dependencies]
dev = ["pytest"]

[project.urls]
Homepage = "https://github.com/fla-org/flame"

[build-system]
requires = ["setuptools>=45", "wheel", "ninja", "torch"]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[tool.isort]
line_length = 127
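
For the dynamic `version = {attr = "flame.__version__"}` entry above to resolve, the package must expose a `__version__` attribute. A minimal sketch of what `src/flame/__init__.py` would contain (the version string itself is an assumption):

```py
# src/flame/__init__.py -- sketch only; the actual version string is an assumption.
# setuptools reads this attribute at build time to populate the project version.
__version__ = "0.1.0"
```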
51 changes: 0 additions & 51 deletions setup.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 1 addition & 2 deletions flame/data.py → src/flame/data.py
@@ -15,10 +15,9 @@
from datasets.iterable_dataset import ShufflingConfig
from torch.distributed.checkpoint.stateful import Stateful
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer

from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from transformers import PreTrainedTokenizer


class BufferShuffledIterableDataset(IterableDataset):
File renamed without changes.
@@ -13,7 +13,6 @@
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks

from torchtitan.tools.logging import logger

try:
2 changes: 1 addition & 1 deletion flame/models/fla.toml → src/flame/models/fla.toml
@@ -64,4 +64,4 @@ enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false

[activation_checkpoint]
mode = "none"
mode = "none"
@@ -11,6 +11,9 @@

import torch
import torch.nn as nn
from fla.modules.fused_linear_cross_entropy import LinearLossParallel
from fla.modules.mlp import SwiGLULinearParallel
from fla.modules.parallel import PrepareModuleWeight
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
@@ -24,10 +27,6 @@
SequenceParallel,
parallelize_module
)

from fla.modules.fused_linear_cross_entropy import LinearLossParallel
from fla.modules.mlp import SwiGLULinearParallel
from fla.modules.parallel import PrepareModuleWeight
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger
@@ -14,13 +14,13 @@
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
from transformers import PretrainedConfig

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
from torchtitan.config_manager import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
from torchtitan.tools.logging import logger
from transformers import PretrainedConfig

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model

DeviceType = Union[int, str, torch.device]

File renamed without changes.
2 changes: 1 addition & 1 deletion flame/tools/utils.py → src/flame/tools/utils.py
@@ -15,7 +15,7 @@ def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple
for m in model.children()
if isinstance(m, nn.Embedding)
)

if hasattr(model_config, "num_heads"):
num_heads = model_config.num_heads
elif hasattr(model_config, "num_attention_heads"):
20 changes: 10 additions & 10 deletions flame/train.py → src/flame/train.py
@@ -9,19 +9,11 @@
import time
from datetime import timedelta

import torch
from torch.distributed.elastic.multiprocessing.errors import record
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import fla # noqa
import torch
from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
from fla.ops.utils import prepare_position_ids
from flame.components.checkpoint import TrainState
from flame.config_manager import JobConfig
from flame.data import build_dataloader, build_dataset
from flame.models.parallelize_fla import parallelize_fla
from flame.models.pipeline_fla import pipeline_fla
from flame.tools.utils import get_nparams_and_flops
from torch.distributed.elastic.multiprocessing.errors import record
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import FTParallelDims, init_ft_manager
from torchtitan.components.loss import build_cross_entropy_loss
@@ -35,6 +27,14 @@
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from flame.components.checkpoint import TrainState
from flame.config_manager import JobConfig
from flame.data import build_dataloader, build_dataset
from flame.models.parallelize_fla import parallelize_fla
from flame.models.pipeline_fla import pipeline_fla
from flame.tools.utils import get_nparams_and_flops


def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
File renamed without changes.
@@ -7,13 +7,12 @@
import tempfile
from datetime import timedelta

import fla # noqa
import torch
import torch.serialization
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import fla # noqa
from torchtitan.tools.logging import init_logger, logger
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@torch.inference_mode()
@@ -4,12 +4,11 @@
import argparse
from pathlib import Path

import fla # noqa
import torch
import torch.distributed.checkpoint as DCP
from transformers import AutoModelForCausalLM

import fla # noqa
from torchtitan.tools.logging import init_logger, logger
from transformers import AutoModelForCausalLM


@torch.inference_mode()
@@ -4,10 +4,10 @@
import argparse
from typing import Any, Dict, List

from torchtitan.tools.logging import init_logger, logger
from transformers import AutoTokenizer, PreTrainedTokenizer

from flame.data import build_dataset
from torchtitan.tools.logging import init_logger, logger


def tokenize(