[tx] WIP: Add MaxText backend support #788
Conversation
Introduces a clean separation between engine (orchestration) and backend (computation):

**New files in `backends/`:**
- `backend.py`: AbstractBackend interface defining the contract (sketched below)
- `native.py`: NativeBackend implementation (extracted from engine.py)
- `utils.py`: Shared utilities (log_timing, pad, pad_batch)
- `__init__.py`: Module exports

**Engine responsibilities (engine.py):**
- Database operations (futures, checkpoints)
- Request validation (`_filter_valid_requests`)
- Data extraction from requests (`_prepare_model_pass_batch`, `_prepare_sample_batch`)
- File I/O (checkpoint download/upload)
- Orchestration of batch processing

**Backend responsibilities (native.py):**
- Model initialization and state management
- JAX/Flax computation (forward, backward, gradient accumulation)
- Optimizer creation and updates
- Checkpoint data extraction/insertion

**New types in types.py:**
- `PreparedModelPassBatch`: Batch data for forward/backward ops
- `PreparedSampleBatch`: Batch data for sampling ops

This is a purely structural refactor with no functional changes to computation logic.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
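A minimal sketch of the backend contract described above. Only `configure_adapter`, `process_forward_backward_batch`, and `process_sample_batch` are named in this PR, so the exact signatures and any other methods are assumptions; the real interface lives in `backends/backend.py`.

```python
from abc import ABC, abstractmethod


class AbstractBackend(ABC):
    """Computation layer; the engine keeps DB access, validation, and file I/O."""

    @abstractmethod
    def configure_adapter(self, model_id: str, lora_config) -> None:
        """Set up a LoRA adapter for a registered model."""

    @abstractmethod
    def process_forward_backward_batch(self, prepared_batch):
        """Run forward/backward on an engine-prepared PreparedModelPassBatch."""

    @abstractmethod
    def process_sample_batch(self, prepared_batch):
        """Run sampling on an engine-prepared PreparedSampleBatch."""
```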
- Added `metrics` property in `TinkerEngine` for backward compatibility with backend metrics.
- Introduced `configure_adapter` method in `AbstractBackend` to streamline LoRA adapter configuration.
- Updated `NativeBackend` to implement the new `configure_adapter` method, replacing the previous `update_adapter_config` call.

These changes improve the modularity and maintainability of the codebase while ensuring compatibility with existing metrics functionality.
- Add maxtext.py backend implementation
- Add maxtext_config_str field to EngineConfig
- Engine selects backend based on config: MaxTextBackend if maxtext_config_str is set (see the sketch below)

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
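A sketch of the selection described above; the config field and class names come from this PR, but where exactly it happens in `TinkerEngine` and the constructor arguments are assumptions.

```python
# Illustrative only: pick the computation backend from the engine config.
if config.maxtext_config_str:
    backend = MaxTextBackend(config)   # MaxText/Tunix computation path
else:
    backend = NativeBackend(config)    # existing JAX/Flax path
```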
- Add maxtext optional dependency from GitHub
- Conditionally include --extra maxtext when maxtext_config_str is set
- Add min_seq_len config parameter for padding buckets
- Increase external inference timeout to 600s
- Use shutil.move for cross-filesystem compatibility (see the note below)
- Update requires-python to ==3.12.*

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
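On the `shutil.move` point: `os.rename` fails with an `EXDEV` error when source and destination are on different filesystems (for example a local scratch directory and a mounted checkpoint volume), while `shutil.move` falls back to copy-and-delete. A minimal illustration with hypothetical paths:

```python
import shutil

# os.rename(src, dst) would fail across filesystem boundaries;
# shutil.move copies the file and removes the source instead.
shutil.move("/tmp/checkpoint.tmp", "/mnt/checkpoints/step_100")
```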
Converts MaxText LoRA tensors to HuggingFace PEFT format for compatibility with external inference engines.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Code Review
This pull request introduces a major refactoring to support multiple backends, with the addition of a MaxTextBackend. The core logic is abstracted into an AbstractBackend interface, and the existing implementation is moved to a NativeBackend. This is a great architectural improvement that enhances modularity. My review focuses on the correctness of the new MaxTextBackend, the refactoring in TinkerEngine, and consistency between the backends. I've found a few critical issues in the MaxTextBackend implementation, such as a method signature mismatch and a syntax error, that need to be addressed. I've also included suggestions for improving robustness, consistency, and code clarity.
```python
def process_forward_backward_batch(
    self,
    requests: dict[str, tuple[str, types.ForwardBackwardInput]],
    models: dict[str, types.ModelMetadata],
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
```
The signature of process_forward_backward_batch does not match the one defined in the AbstractBackend interface. The TinkerEngine now prepares a PreparedModelPassBatch object and passes it to the backend. This implementation still uses the old signature and contains batch preparation logic that has been moved to the engine. This will cause a runtime error. The method should be updated to accept a prepared_batch: types.PreparedModelPassBatch and use the data from it, removing the redundant preparation logic.
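A sketch of the expected signature after the refactor, assuming the engine passes a single `PreparedModelPassBatch` as described; what the backend reads off that object is an assumption.

```python
def process_forward_backward_batch(
    self,
    prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    # The engine has already validated requests and extracted the batch tensors,
    # so the backend only runs the computation on prepared_batch.
    ...
```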
```python
def _get_maxtext_base_config_path() -> str:
    """Get the absolute path to MaxText's base.yml config file."""
    if not MAXTEXT_AVAILABLE:
        return ""
    import os

    maxtext_pkg_dir = os.path.dirname(MaxText.__file__)
    maxtext_root = os.path.dirname(os.path.dirname(maxtext_pkg_dir))
    config_path = os.path.join(maxtext_root, "src", "MaxText", "configs", "base.yml")
    if not os.path.exists(config_path):
        config_path = os.path.expanduser("~/maxtext/src/MaxText/configs/base.yml")
    return config_path
```
The function _get_maxtext_base_config_path relies on hardcoded paths to find the base.yml config file. This is brittle and may fail if the maxtext package structure changes or if the user has a different setup. It would be more robust to make this path configurable, or to use a mechanism like importlib.resources if MaxText packages its config file.
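For illustration, a hedged sketch using `importlib.resources`; whether MaxText actually ships `configs/base.yml` as package data is an assumption, so the resource path may need adjusting.

```python
from importlib import resources


def _get_maxtext_base_config_path() -> str:
    """Locate MaxText's base.yml via package resources instead of hardcoded paths."""
    if not MAXTEXT_AVAILABLE:
        raise RuntimeError("MaxText is not installed")
    # Assumes base.yml is packaged under MaxText/configs/ as package data.
    cfg = resources.files("MaxText").joinpath("configs", "base.yml")
    if not cfg.is_file():
        raise FileNotFoundError(f"MaxText base config not found at {cfg}")
    return str(cfg)
```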
good catch
```python
with jax.set_mesh(self.mesh):
    base_model, _ = maxtext_model_creation.create_nnx_model(maxtext_config, mesh=self.mesh)
    self.model = TunixMaxTextAdapter(base_model=base_model)
    self.model.config = None
```
this model.config = None is copied from maxtext, i don't know why it's like this
```diff
 description = "Unified API for training and inference"
 readme = "README.md"
-requires-python = ">=3.11"
+requires-python = "==3.12.*"
```
The requires-python has been changed to ==3.12.*, which is very restrictive. This will prevent users on any other Python version (for example 3.11 or future 3.13+ releases) from using the library. Unless there's a strong reason to pin to 3.12 exactly, consider using a more flexible specifier like >=3.12.
```diff
-requires-python = "==3.12.*"
+requires-python = ">=3.12"
```
maxtext is annoying and requires python 3.12; actually, i haven't tested it with python 3.13
```python
def pad_batch(sequences: list[list], max_length: int, dtype, left: bool = False) -> jax.Array:
    """Pad a batch of sequences to max_length."""
    batch_size = len(sequences)
    padded = np.zeros((batch_size, max_length), dtype=dtype)
    for i, seq in enumerate(sequences):
        assert len(seq) <= max_length, f"Sequence length {len(seq)} exceeds max_length {max_length}"
        if left:
            padded[i, max_length - len(seq):] = seq
        else:
            padded[i, :len(seq)] = seq
    return jnp.asarray(padded)
```
| print(f"MaxText forward-backward: batch [{mb_start}:{mb_end}], seq_len={seq_len}", flush=True) | ||
| tic = time.time() | ||
|
|
||
| _, target_logprobs, per_token_losses, grads = self._forward_backward( | ||
| self.model, | ||
| input_ids[mb_start:mb_end], | ||
| positions[mb_start:mb_end], | ||
| target_ids[mb_start:mb_end], | ||
| loss_mask[mb_start:mb_end], | ||
| ) | ||
|
|
||
| _ = jax.device_get(target_logprobs) | ||
|
|
||
| took = time.time() - tic | ||
| tokens_processed = (mb_end - mb_start) * seq_len | ||
| tokens_per_sec = tokens_processed / took if took > 0 else float('nan') | ||
| print(f"Batch [{mb_start}:{mb_end}] forward-backward time: {took:.3f} sec, tokens/sec: {tokens_per_sec:,.1f}", flush=True) |
```python
if num_layers is None:
    # Try to infer from shapes
    for tensor in lora_tensors.values():
        for dim in tensor.shape:
            if dim in [36, 48, 64, 80, 96]:  # Common layer counts
                num_layers = dim
                break
        if num_layers:
            break

if num_layers is None:
    raise ValueError("Could not determine num_layers from tensor shapes")
```
The logic to determine num_layers by checking for common layer counts in tensor shapes is brittle. This might fail for models with a different number of layers. To make this function more robust, consider adding num_layers as an optional argument. If it's not provided, you can fall back to this inference logic.
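A sketch of that fallback as a small helper; the name `_resolve_num_layers` is hypothetical, and the heuristic simply reuses the layer counts from the current code.

```python
def _resolve_num_layers(lora_tensors, num_layers=None):
    """Prefer an explicitly supplied num_layers; fall back to shape inference."""
    if num_layers is not None:
        return num_layers
    common_layer_counts = {36, 48, 64, 80, 96}  # same heuristic as the existing code
    for tensor in lora_tensors.values():
        for dim in tensor.shape:
            if dim in common_layer_counts:
                return dim
    raise ValueError("Could not determine num_layers from tensor shapes; pass it explicitly.")
```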
good catch

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Pin to 1edde4d1b1d562173d1753650b0234aa5c6a2fea for reproducible builds.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
…into maxtext_backend
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Try importlib.resources first for proper package resolution
- Raise FileNotFoundError with a helpful message instead of the hardcoded fallback
- Raise RuntimeError if MaxText is not available

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
- process_forward_backward_batch now accepts PreparedModelPassBatch
- Add process_forward_batch stub with correct signature
- process_sample_batch now accepts PreparedSampleBatch

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Implement checkpoint creation logic to handle existing checkpoints by deleting them before adding new ones.
- Introduce model caching in TinkerEngine to manage least recently used models and optimize resource allocation (see the sketch below).
- Update backend methods to manage optimizers internally, removing the need for external optimizer handling.
- Refactor model registration and unregistration processes across backends to streamline model lifecycle management.

This update improves the efficiency of model management and checkpoint handling in the Tinker framework.
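A minimal sketch of the LRU-style model cache described above; the class name, capacity, and eviction hook are assumptions, and the real logic lives in `TinkerEngine`.

```python
from collections import OrderedDict


class ModelCache:
    """Keep at most max_models live models, evicting the least recently used."""

    def __init__(self, max_models: int = 4):
        self.max_models = max_models
        self._models: OrderedDict[str, object] = OrderedDict()

    def get(self, model_id: str):
        model = self._models.get(model_id)
        if model is not None:
            self._models.move_to_end(model_id)  # mark as most recently used
        return model

    def put(self, model_id: str, model) -> None:
        self._models[model_id] = model
        self._models.move_to_end(model_id)
        while len(self._models) > self.max_models:
            evicted_id, _ = self._models.popitem(last=False)  # least recently used
            # the engine would ask the backend to unregister evicted_id here
```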
- Add a new function to reset LoRA weights, specifically zeroing out lora_b while preserving lora_a (sketched below).
- Update the unregister_model method to utilize the new reset_adapter_weights function for improved weight management.
- Ensure model state is synchronized with updated LoRA parameters after unregistration.

This change enhances the handling of model weights in the MaxText backend, improving the efficiency of model lifecycle management.
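A sketch of the reset described above, assuming the adapter parameters are a flat dict of JAX arrays keyed by parameter path; the actual pytree layout in the MaxText backend may differ.

```python
import jax.numpy as jnp


def reset_adapter_weights(params: dict) -> dict:
    """Zero out lora_b so the adapter contributes nothing; keep lora_a unchanged."""
    return {
        name: jnp.zeros_like(p) if "lora_b" in name else p
        for name, p in params.items()
    }
```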
- Introduce a new TTL storage structure in ttl.md, including functions for setting and getting values with expiration, and a cleanup function for expired entries.
- Update the NativeBackend class to enable eager sharding during model creation, enhancing performance and resource management.
- Adjust loss calculation to prevent division by zero by using a maximum function on the loss mask (see the one-line sketch below).

These changes improve memory management and model handling in the Tinker framework.
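On the divide-by-zero point, the usual guard is to clamp the denominator; a one-line sketch with assumed variable names (`per_token_losses`, `loss_mask`):

```python
# If every position is masked out, loss_mask.sum() is 0; clamp it so the
# mean loss stays finite instead of producing NaN/inf.
mean_loss = per_token_losses.sum() / jnp.maximum(loss_mask.sum(), 1.0)
```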
- Introduce a mechanism to evict old sampler checkpoints in the Tinker framework, ensuring that only a specified number of checkpoints are retained per model.
- Update the MaxText backend to support dynamic loss function selection, allowing for more flexible loss calculations during training.
- Integrate the `tenacity` library for retrying failed HTTP requests in the ExternalInferenceClient, improving robustness against transient errors.

These changes enhance model management and error handling in the Tinker framework.
…ndling

- Introduce a new helper function `_is_retryable_error` to determine if exceptions are retryable, based on connection errors and 5xx HTTP status codes.
- Update the retry decorator in the ExternalInferenceClient to utilize this new function, enhancing the robustness of HTTP request handling (see the sketch below).

These changes improve the error recovery mechanism in the Tinker framework.
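A sketch of the retry wiring described above; the `_is_retryable_error` name comes from this commit, while the use of `httpx` exceptions, the wrapped `_post` helper, and the specific retry limits are assumptions.

```python
import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


def _is_retryable_error(exc: BaseException) -> bool:
    """Retry on connection-level failures and 5xx responses only."""
    if isinstance(exc, (httpx.ConnectError, httpx.ReadTimeout)):
        return True
    if isinstance(exc, httpx.HTTPStatusError):
        return exc.response.status_code >= 500
    return False


@retry(
    retry=retry_if_exception(_is_retryable_error),
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=30),
)
def _post(client: httpx.Client, url: str, payload: dict) -> httpx.Response:
    resp = client.post(url, json=payload)
    resp.raise_for_status()  # raises HTTPStatusError, which the predicate inspects
    return resp
```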
(dependent on #787)
Summary
- `convert_maxtext_lora_to_hf` function for converting MaxText LoRA to HuggingFace PEFT format
- `min_seq_len` config parameter for padding buckets

Status
🚧 Work in Progress - Not ready for merge
Test plan
🤖 Generated with Claude Code