Skip to content

Commit 59b37eb

Browse files
Merge pull request #2498 from SamuelMarks:pre-commit-running
PiperOrigin-RevId: 832031708
2 parents c528626 + 4c3fd67 commit 59b37eb

File tree

10 files changed

+147
-39
lines changed

10 files changed

+147
-39
lines changed
Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
name: Linter
15+
name: CodeQuality
1616

1717
on:
1818
pull_request:
@@ -28,22 +28,33 @@ concurrency:
2828
cancel-in-progress: true
2929

3030
jobs:
31-
cpu:
32-
name: "CPU tests"
31+
qa:
32+
name: "Static code-quality checkers"
3333
runs-on: ubuntu-latest
34-
strategy:
35-
matrix:
36-
os: [ubuntu-24.04]
37-
python-version: ['3.12']
3834
steps:
3935
- uses: actions/checkout@v5
40-
- name: Set up Python ${{ matrix.python-version }}
41-
uses: actions/setup-python@v5
36+
37+
- name: Install uv and set the Python version
38+
uses: astral-sh/setup-uv@v7
4239
with:
43-
python-version: ${{ matrix.python-version }}
44-
- name: Install Dependencies
45-
run: |
46-
python3 -m pip install --upgrade pip
47-
python3 -m pip install pre-commit
40+
python-version: '3.12'
41+
enable-cache: true
42+
43+
- name: Set venv
44+
run: uv venv --python 3.12 "$GITHUB_WORKSPACE"/venv
45+
46+
- name: Install `pre-commit`
47+
run: . "$GITHUB_WORKSPACE"/venv/bin/activate && uv pip install pre-commit
48+
49+
- name: Cache pre-commit environments
50+
uses: actions/cache@v4
51+
with:
52+
path: ~/.cache/pre-commit
53+
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
54+
4855
- name: Run pre-commit checks on just the files that have changed
49-
run: pre-commit run
56+
run: |
57+
git fetch origin "$GITHUB_BASE_REF":"$GITHUB_BASE_REF"
58+
git branch "$GITHUB_HEAD_REF"
59+
. "$GITHUB_WORKSPACE"/venv/bin/activate
60+
pre-commit run --from-ref "$GITHUB_BASE_REF" --to-ref "$GITHUB_HEAD_REF"

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ repos:
99
args:
1010
- '-w'
1111
- '--skip="*.txt,pylintrc,.*,src/MaxText/assets/*"'
12-
- '-L ND,nd,sems,TE,ROUGE,rouge,astroid,dout'
12+
- '-L ND,nd,sems,TE,ROUGE,rouge,astroid,ags,dout'
1313
- '.'
1414
additional_dependencies:
1515
- tomli

benchmarks/api_server/server_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class CompletionRequest(SamplingParams):
6969
logprobs: Optional[int] = None
7070

7171
@field_validator("logprobs")
72+
@classmethod
7273
def validate_logprobs(cls, v):
7374
if v is not None and v < 0:
7475
raise ValueError("logprobs must be a non-negative integer if provided.")

end_to_end/gpu/te/run_single_node_model_parallel.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,13 @@ run_and_parse() {
115115
echo "===== Executing ${test}\t${dp}\t${tpsp}\t${fsdp} ====="
116116
eval "$cmd" 2>&1 | tee "$stdout"
117117
# Exclude the warning steps for warning up and last step for tracing
118-
ths=$(grep 'Tokens/s/device:' "$stdout" | sed '1,'"${WARMUP_STEPS}"'d;$d' | awk -F'Tokens/s/device: ' '{print $2}' | awk -F',' '{print $1}')
118+
std=$(grep 'Tokens/s/device:' "$stdout" | sed '1,'"${WARMUP_STEPS}"'d;$d' | awk -F'Tokens/s/device: ' '{print $2}' | awk -F',' '{print $1}')
119119

120-
if [ -z "$ths" ]; then
120+
if [ -z "$std" ]; then
121121
mean="NA"
122122
stddev="NA"
123123
else
124-
mean_stddev=$(echo "$ths" | python3 -c "import sys; import numpy as np
124+
mean_stddev=$(echo "$std" | python3 -c "import sys; import numpy as np
125125
arr = [float(l.strip()) for l in sys.stdin if l.strip()]
126126
if arr:
127127
print(f'{np.mean(arr):.2f}\t{np.std(arr, ddof=1):.2f}')

src/MaxText/examples/sft_train_and_evaluate.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383

8484
from flax import nnx
8585

86-
from MaxText import globals
86+
from MaxText.globals import MAXTEXT_REPO_ROOT
8787
from MaxText import max_logging
8888
from MaxText import max_utils
8989
from MaxText import pyconfig
@@ -122,10 +122,21 @@
122122
)
123123
# Regex to extract the final numerical answer
124124
MATCH_ANSWER = re.compile(rf"{ANSWER_START}.*?([\d\.\,\$]{{1,}})", flags=re.MULTILINE | re.DOTALL)
125-
CHAT_TEMPLATE_PATH = f"{globals.MAXTEXT_REPO_ROOT}/src/MaxText/examples/chat_templates/math_qa.json"
125+
CHAT_TEMPLATE_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "examples", "chat_templates", "math_qa.json")
126126

127127

128128
def get_test_dataset(config, tokenizer):
129+
"""Loads and prepares the test dataset from Hugging Face.
130+
131+
Args:
132+
config: The pyconfig object containing run configurations, including
133+
`hf_access_token`.
134+
tokenizer: The tokenizer for processing the text data.
135+
136+
Returns:
137+
A grain.MapDataset instance for the test split, with prompts and target
138+
answers.
139+
"""
129140
template_config = instruction_data_processing.load_template_from_file(CHAT_TEMPLATE_PATH)
130141
dataset = datasets.load_dataset(
131142
DATASET_NAME,
@@ -159,7 +170,17 @@ def get_test_dataset(config, tokenizer):
159170

160171

161172
def evaluate_model(dataset, vllm_rollout, debug=True):
162-
"""Runs evaluation on the model using vLLM."""
173+
"""Runs evaluation on the model using vLLM.
174+
175+
Args:
176+
dataset: The dataset to evaluate on.
177+
vllm_rollout: The vLLM rollout object for generating responses.
178+
debug: If True, prints debug information for each sample.
179+
180+
Returns:
181+
A dictionary containing evaluation scores: 'correct', 'partially_correct',
182+
and 'correct_format' percentages.
183+
"""
163184
rollout_config = base_rollout.RolloutConfig(
164185
max_tokens_to_generate=MAX_TOKENS_TO_GENERATE,
165186
max_prompt_length=MAX_PROMPT_LENGTH,
@@ -201,12 +222,35 @@ def evaluate_model(dataset, vllm_rollout, debug=True):
201222

202223

203224
def safe_string_to_float(text):
225+
"""Cleans a string to make it safely convertible to a float.
226+
227+
Removes commas, spaces, and dollar signs.
228+
229+
Args:
230+
text: The input string.
231+
232+
Returns:
233+
The cleaned string.
234+
"""
204235
text = text.replace(",", "").replace(" ", "") # converts "2,125" to "2125"
205236
text = text.replace("$", "") # converts "$50" to "50"
206237
return text
207238

208239

209240
def score_response(target, prediction, debug=True):
241+
"""Scores the model's prediction against the target answer.
242+
243+
It checks for exact correctness, partial correctness (within 10%), and
244+
whether the response follows the expected format.
245+
246+
Args:
247+
target: The ground truth answer string.
248+
prediction: The model's generated response string.
249+
debug: If True, prints exceptions during scoring.
250+
251+
Returns:
252+
A tuple of booleans: (is_correct, is_partially_correct, has_correct_format).
253+
"""
210254
is_correct, is_partially_correct, has_correct_format = False, False, False
211255
extracted_response = guess.group(1) if (guess := MATCH_ANSWER.search(prediction)) is not None else ""
212256
extracted_response = safe_string_to_float(extracted_response)
@@ -231,6 +275,17 @@ def score_response(target, prediction, debug=True):
231275

232276

233277
def create_vllm_rollout(config, model, mesh, tokenizer):
278+
"""Creates a vLLM rollout engine for text generation.
279+
280+
Args:
281+
config: The pyconfig object containing run configurations.
282+
model: The NNX model graph.
283+
mesh: The JAX device mesh.
284+
tokenizer: The tokenizer.
285+
286+
Returns:
287+
A VllmRollout instance configured for the model and hardware.
288+
"""
234289
tunix_model = TunixMaxTextAdapter(model)
235290
return VllmRollout(
236291
model=tunix_model,
@@ -245,6 +300,14 @@ def create_vllm_rollout(config, model, mesh, tokenizer):
245300

246301

247302
def get_tokenizer(config):
303+
"""Initializes and returns the tokenizer.
304+
305+
Args:
306+
config: The pyconfig object with `tokenizer_path` and `hf_access_token`.
307+
308+
Returns:
309+
A Hugging Face tokenizer instance.
310+
"""
248311
tokenizer = transformers.AutoTokenizer.from_pretrained(
249312
config.tokenizer_path,
250313
token=config.hf_access_token,
@@ -253,6 +316,11 @@ def get_tokenizer(config):
253316

254317

255318
def train_and_evaluate(config):
319+
"""Orchestrates the pre-train evaluation, SFT, and post-train evaluation.
320+
321+
Args:
322+
config: The pyconfig object containing all run configurations.
323+
"""
256324
tokenizer = get_tokenizer(config)
257325
test_dataset = get_test_dataset(config, tokenizer)
258326
test_dataset = test_dataset[:NUM_TEST_SAMPLES]
@@ -261,16 +329,16 @@ def train_and_evaluate(config):
261329
vllm_rollout = create_vllm_rollout(config, trainer.model, mesh, tokenizer)
262330

263331
# 1. Pre-SFT Evaluation
264-
max_logging.log(f"Running Pre-SFT evaluation...")
332+
max_logging.log("Running Pre-SFT evaluation...")
265333
score = evaluate_model(test_dataset, vllm_rollout)
266334
print("Score for PRE-SFT EVALUATION: ", score)
267335

268336
# 2. SFT Training
269-
max_logging.log(f"Starting SFT training...")
337+
max_logging.log("Starting SFT training...")
270338
trainer = sft_trainer.train_model(config, trainer, mesh)
271339

272340
# 3. Post-SFT Evaluation
273-
max_logging.log(f"Running Post-SFT evaluation...")
341+
max_logging.log("Running Post-SFT evaluation...")
274342
tunix_model = TunixMaxTextAdapter(trainer.model)
275343
state = nnx.state(tunix_model)
276344
vllm_rollout.update_params(state)

src/MaxText/layers/quantizations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -797,11 +797,11 @@ def _wrap(self, f, name=None):
797797
A Flax linen module that wraps the given function.
798798
"""
799799

800-
import transformer_engine.jax as te # pylint: disable=import-outside-toplevel # pytype: disable=import-error
800+
import transformer_engine.jax # pylint: disable=import-outside-toplevel # pytype: disable=import-error
801801

802802
fp8_recipe = self._recipe
803803

804-
class TEWrapper(te.flax.module.TransformerEngineBase):
804+
class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase):
805805
"""Wrapper module for TransformerEngine quantization."""
806806

807807
def generate_quantizer_set(self, postfix: str = ""):
@@ -820,14 +820,14 @@ def __call__(self, *args, **kwargs):
820820

821821
def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
822822
"""Placeholder for dot_general implementation in subclasses."""
823-
import transformer_engine.jax as te # pylint: disable=import-outside-toplevel # pytype: disable=import-error
823+
import transformer_engine.jax # pylint: disable=import-outside-toplevel # pytype: disable=import-error
824824

825825
def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs):
826826
contracting_dims, batch_dims = dims
827827
assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot."
828828

829829
quantizer_set = generate_quantizer_set()
830-
return te.dense.dense(
830+
return transformer_engine.jax.dense.dense(
831831
x,
832832
kernel,
833833
contracting_dims=contracting_dims,

src/MaxText/max_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,21 +991,24 @@ def get_batch_seq_len_for_mode(config, model_mode):
991991

992992
return batch_size, seq_len
993993

994+
994995
@contextmanager
995996
def maybe_get_transformer_engine_context(config):
996-
""" Runs a transformer engine context engine manager for GPUs only. """
997-
if config.hardware in ['gpu', 'gpu_multiprocess']:
997+
"""Runs a transformer engine context engine manager for GPUs only."""
998+
if config.hardware in ["gpu", "gpu_multiprocess"]:
998999
with transformer_engine_context():
9991000
yield
10001001
else:
10011002
with dummy_context_manager():
10021003
yield
10031004

1005+
10041006
@contextmanager
10051007
def dummy_context_manager():
10061008
"""A context manager that does nothing."""
10071009
yield
10081010

1011+
10091012
@contextmanager
10101013
def transformer_engine_context():
10111014
"""If TransformerEngine is available, this context manager will provide

src/MaxText/train.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# See github.com/google/maxtext/issues/20 for more
2020

2121
from typing import Any, Sequence
22-
from contextlib import contextmanager
2322
import datetime
2423
import functools
2524
import os
@@ -522,14 +521,13 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
522521
def run(config, recorder, diagnostic_config):
523522
"""Run the job given hyperparameters and utilities"""
524523
with (
525-
diagnostic.diagnose(diagnostic_config),
526-
maybe_record_goodput(recorder, GoodputEvent.JOB),
527-
max_utils.maybe_get_transformer_engine_context(config)
524+
diagnostic.diagnose(diagnostic_config),
525+
maybe_record_goodput(recorder, GoodputEvent.JOB),
526+
max_utils.maybe_get_transformer_engine_context(config),
528527
):
529528
train_loop(config, recorder)
530529

531530

532-
533531
def main(argv: Sequence[str]) -> None:
534532
config, recorder, diagnostic_config = initialize(argv)
535533
run(config, recorder, diagnostic_config)

src/MaxText/utils/ckpt_scripts/dequantize_mxfp4.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
Example cmd:
1818
1919
python3 -m MaxText.utils.ckpt_scripts.dequantize_mxfp4 --input-path=<input_path> --output-path=<output_path>
20-
python3 -m MaxText.utils.ckpt_scripts.dequantize_mxfp4 --input-path=<input_path> --output-path=<output_path> --dtype-str=bf16 --cache-size=2
20+
python3 -m MaxText.utils.ckpt_scripts.dequantize_mxfp4 --input-path=<input_path> --output-path=<output_path> \
21+
--dtype-str=bf16 --cache-size=2
2122
"""
2223

2324
import os

tests/integration_tests/checkpointing_test.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@
3434

3535

3636
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path):
37+
"""Generates a command list for a checkpointing test run.
38+
39+
Args:
40+
run_date: The date of the run.
41+
hardware: The hardware to run on.
42+
steps: The number of steps to run.
43+
metrics_file: The file to write metrics to.
44+
attention_type: The type of attention to use.
45+
dataset_type: The type of dataset to use.
46+
dataset_path: The path to the dataset.
47+
48+
Returns:
49+
A list of strings representing the command line arguments.
50+
"""
3751
model_params = [
3852
"base_emb_dim=384",
3953
"base_num_query_heads=8",
@@ -71,7 +85,12 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
7185

7286

7387
def check_loss(metrics_file, target):
74-
"""Asserts over loss values from loaded checkpoint"""
88+
"""Asserts over loss values from loaded checkpoint.
89+
90+
Args:
91+
metrics_file: The base name of the metrics file.
92+
target: The target metric to check in the metrics file.
93+
"""
7594
metrics_file_saved = "saved_" + metrics_file
7695
metrics_file_restored = "restored_" + metrics_file
7796

@@ -89,7 +108,12 @@ def check_loss(metrics_file, target):
89108

90109

91110
def run_checkpointing(hardware, attention_type):
92-
"""Tests grain checkpoint determinism."""
111+
"""Tests checkpointing by saving and restoring a model.
112+
113+
Args:
114+
hardware: The hardware to run on.
115+
attention_type: The type of attention to use.
116+
"""
93117
run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
94118
grain_command = [
95119
"grain_worker_count=0",
@@ -127,10 +151,12 @@ def run_checkpointing(hardware, attention_type):
127151
@pytest.mark.integration_test
128152
@pytest.mark.tpu_only
129153
def test_autoselected_attention():
154+
"""Tests checkpointing with autoselected attention on TPU."""
130155
run_checkpointing("tpu", "autoselected")
131156

132157

133158
@pytest.mark.integration_test
134159
@pytest.mark.gpu_only
135160
def test_with_dot_product():
161+
"""Tests checkpointing with dot_product attention on GPU."""
136162
run_checkpointing("gpu", "dot_product")

0 commit comments

Comments
 (0)