
Conversation

@OhadRubin OhadRubin commented Dec 17, 2025

(dependent on #787)

Summary

  • Add MaxTextBackend as an alternative to NativeBackend for context parallelism support
  • Add convert_maxtext_lora_to_hf function for converting MaxText LoRA to HuggingFace PEFT format
  • Add maxtext optional dependency
  • Add min_seq_len config parameter for padding buckets
  • Increase the external inference timeout and use cross-filesystem-safe file moves

Status

🚧 Work in Progress - Not ready for merge

Test plan

  • Test MaxTextBackend initialization
  • Test LoRA conversion to HuggingFace format
  • End-to-end training with MaxText backend

🤖 Generated with Claude Code

OhadRubin and others added 5 commits December 17, 2025 08:17
Introduces a clean separation between engine (orchestration) and backend (computation):

**New files in `backends/`:**
- `backend.py`: AbstractBackend interface defining the contract
- `native.py`: NativeBackend implementation (extracted from engine.py)
- `utils.py`: Shared utilities (log_timing, pad, pad_batch)
- `__init__.py`: Module exports

**Engine responsibilities (engine.py):**
- Database operations (futures, checkpoints)
- Request validation (`_filter_valid_requests`)
- Data extraction from requests (`_prepare_model_pass_batch`, `_prepare_sample_batch`)
- File I/O (checkpoint download/upload)
- Orchestration of batch processing

**Backend responsibilities (native.py):**
- Model initialization and state management
- JAX/Flax computation (forward, backward, gradient accumulation)
- Optimizer creation and updates
- Checkpoint data extraction/insertion

**New types in types.py:**
- `PreparedModelPassBatch`: Batch data for forward/backward ops
- `PreparedSampleBatch`: Batch data for sampling ops

This is a purely structural refactor - no functional changes to computation logic.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
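For orientation, here is a minimal sketch of what the `AbstractBackend` contract might look like, with method names and prepared-batch types taken from this thread; the actual signatures in `backends/backend.py` may differ.

```python
# Hypothetical sketch only; signatures are assumptions based on the
# commit messages and review comments in this PR.
import abc

from tx.tinker import types  # assumed import path


class AbstractBackend(abc.ABC):
    """Contract between the orchestrating engine and a computation backend."""

    @abc.abstractmethod
    def process_forward_backward_batch(
        self, prepared_batch: types.PreparedModelPassBatch
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Run forward/backward over an engine-prepared batch."""

    @abc.abstractmethod
    def process_sample_batch(self, prepared_batch: types.PreparedSampleBatch) -> dict:
        """Generate samples from an engine-prepared batch."""
```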
- Added `metrics` property in `TinkerEngine` for backward compatibility with backend metrics.
- Introduced `configure_adapter` method in `AbstractBackend` to streamline LoRA adapter configuration.
- Updated `NativeBackend` to implement the new `configure_adapter` method, replacing the previous `update_adapter_config` call.

These changes improve the modularity and maintainability of the codebase while ensuring compatibility with existing metrics functionality.
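A hedged sketch of how `configure_adapter` might slot into that interface; the argument types are assumptions, not the actual signature.

```python
import abc


class AbstractBackend(abc.ABC):
    # ...methods from the sketch above...

    @abc.abstractmethod
    def configure_adapter(self, model_id: str, lora_config: dict) -> None:
        """Apply LoRA adapter settings (e.g. rank, alpha) for a registered model."""
```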
- Add maxtext.py backend implementation
- Add maxtext_config_str field to EngineConfig
- Engine selects backend based on config (MaxTextBackend if maxtext_config_str is set)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
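A minimal sketch of the selection logic this describes, assuming the module paths from the refactor above:

```python
# Illustrative only: choose the backend from EngineConfig.maxtext_config_str.
def create_backend(config):
    if config.maxtext_config_str:
        from tx.tinker.backends.maxtext import MaxTextBackend  # assumed path
        return MaxTextBackend(config)
    from tx.tinker.backends.native import NativeBackend  # assumed path
    return NativeBackend(config)
```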
- Add maxtext optional dependency from GitHub
- Conditionally include --extra maxtext when maxtext_config_str is set
- Add min_seq_len config parameter for padding buckets
- Increase external inference timeout to 600s
- Use shutil.move for cross-filesystem compatibility
- Update requires-python to ==3.12.*

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
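For illustration, one way a `min_seq_len` padding bucket could work; the actual bucketing scheme is not shown in this PR, so treat this as an assumption:

```python
# Pad sequence lengths up to power-of-two buckets floored at min_seq_len,
# so JIT compilation sees a bounded set of shapes.
def bucket_length(seq_len: int, min_seq_len: int) -> int:
    bucket = min_seq_len
    while bucket < seq_len:
        bucket *= 2
    return bucket


assert bucket_length(100, min_seq_len=128) == 128
assert bucket_length(300, min_seq_len=128) == 512
```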
Converts MaxText LoRA tensors to HuggingFace PEFT format for
compatibility with external inference engines.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
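A rough sketch of the conversion direction: MaxText typically stacks LoRA weights for all layers into one tensor, while PEFT expects one tensor per layer under per-layer key names. The key template and layout below are assumptions, not the actual implementation.

```python
import numpy as np


# Hypothetical: split a (num_layers, ...) MaxText tensor into per-layer PEFT
# entries named like
# "base_model.model.model.layers.{i}.self_attn.q_proj.lora_A.weight".
def split_stacked_lora(stacked: np.ndarray, key_template: str) -> dict[str, np.ndarray]:
    return {key_template.format(i=i): stacked[i] for i in range(stacked.shape[0])}
```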
OhadRubin changed the title from "WIP: Add MaxText backend support" to "[tx] WIP: Add MaxText backend support" on Dec 17, 2025.

gemini-code-assist bot left a comment


Code Review

This pull request introduces a major refactoring to support multiple backends, with the addition of a MaxTextBackend. The core logic is abstracted into an AbstractBackend interface, and the existing implementation is moved to a NativeBackend. This is a great architectural improvement that enhances modularity. My review focuses on the correctness of the new MaxTextBackend, the refactoring in TinkerEngine, and consistency between the backends. I've found a few critical issues in the MaxTextBackend implementation, such as a method signature mismatch and a syntax error, that need to be addressed. I've also included suggestions for improving robustness, consistency, and code clarity.

Comment on lines 277 to 281
```python
def process_forward_backward_batch(
    self,
    requests: dict[str, tuple[str, types.ForwardBackwardInput]],
    models: dict[str, types.ModelMetadata],
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
```

critical

The signature of process_forward_backward_batch does not match the one defined in the AbstractBackend interface. The TinkerEngine now prepares a PreparedModelPassBatch object and passes it to the backend. This implementation still uses the old signature and contains batch preparation logic that has been moved to the engine. This will cause a runtime error. The method should be updated to accept a prepared_batch: types.PreparedModelPassBatch and use the data from it, removing the redundant preparation logic.
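The suggested shape of the fix, mirroring the interface sketched earlier:

```python
def process_forward_backward_batch(
    self,
    prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    ...  # consume engine-prepared data; no batch preparation here
```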

Comment on lines 33 to 43
```python
def _get_maxtext_base_config_path() -> str:
    """Get the absolute path to MaxText's base.yml config file."""
    if not MAXTEXT_AVAILABLE:
        return ""
    import os
    maxtext_pkg_dir = os.path.dirname(MaxText.__file__)
    maxtext_root = os.path.dirname(os.path.dirname(maxtext_pkg_dir))
    config_path = os.path.join(maxtext_root, "src", "MaxText", "configs", "base.yml")
    if not os.path.exists(config_path):
        config_path = os.path.expanduser("~/maxtext/src/MaxText/configs/base.yml")
    return config_path
```

high

The function _get_maxtext_base_config_path relies on hardcoded paths to find the base.yml config file. This is brittle and may fail if the maxtext package structure changes or if the user has a different setup. It would be more robust to make this path configurable, or to use a mechanism like importlib.resources if MaxText packages its config file.
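One possible `importlib.resources`-based lookup, assuming MaxText ships `base.yml` as package data under `MaxText/configs` (not verified):

```python
from importlib import resources


def _get_maxtext_base_config_path() -> str:
    cfg = resources.files("MaxText").joinpath("configs", "base.yml")
    if not cfg.is_file():
        raise FileNotFoundError(
            "Could not locate MaxText base.yml; is maxtext installed with its package data?"
        )
    return str(cfg)
```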

OhadRubin (Author) replied:

good catch

```python
with jax.set_mesh(self.mesh):
    base_model, _ = maxtext_model_creation.create_nnx_model(maxtext_config, mesh=self.mesh)
    self.model = TunixMaxTextAdapter(base_model=base_model)
    self.model.config = None
```

high

Setting self.model.config = None is suspicious and could lead to AttributeErrors if other parts of the code expect a valid config object on the model. Please add a comment explaining why this is necessary or ensure that no code path relies on self.model.config.

OhadRubin (Author) replied:

this `model.config = None` is copied from MaxText; I don't know why it's like this

description = "Unified API for training and inference"
readme = "README.md"
requires-python = ">=3.11"
requires-python = "==3.12.*"

medium

The requires-python has been changed to ==3.12.*, which is very restrictive. This will prevent users on other Python 3.12 patch versions or future minor versions from using the library. Unless there's a strong reason for this exact version, consider using a more flexible specifier like >=3.12.

Suggested change

```diff
-requires-python = "==3.12.*"
+requires-python = ">=3.12"
```

OhadRubin (Author) replied:

MaxText is annoying and requires Python 3.12. Actually, I haven't tested it with Python 3.13.

Comment on lines 59 to 69
```python
def pad_batch(sequences: list[list], max_length: int, dtype, left: bool = False) -> jax.Array:
    """Pad a batch of sequences to max_length."""
    batch_size = len(sequences)
    padded = np.zeros((batch_size, max_length), dtype=dtype)
    for i, seq in enumerate(sequences):
        assert len(seq) <= max_length, f"Sequence length {len(seq)} exceeds max_length {max_length}"
        if left:
            padded[i, max_length - len(seq):] = seq
        else:
            padded[i, :len(seq)] = seq
    return jnp.asarray(padded)
```

medium

The pad_batch function is duplicated here. A shared version has been added in tx.tinker.backends.utils. Please remove this duplicated implementation and import it from the utils module to avoid code duplication.
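That is, the fix reduces to a one-line import (path per the comment above):

```python
from tx.tinker.backends.utils import pad_batch  # shared implementation
```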

Comment on lines 336 to 352
print(f"MaxText forward-backward: batch [{mb_start}:{mb_end}], seq_len={seq_len}", flush=True)
tic = time.time()

_, target_logprobs, per_token_losses, grads = self._forward_backward(
self.model,
input_ids[mb_start:mb_end],
positions[mb_start:mb_end],
target_ids[mb_start:mb_end],
loss_mask[mb_start:mb_end],
)

_ = jax.device_get(target_logprobs)

took = time.time() - tic
tokens_processed = (mb_end - mb_start) * seq_len
tokens_per_sec = tokens_processed / took if took > 0 else float('nan')
print(f"Batch [{mb_start}:{mb_end}] forward-backward time: {took:.3f} sec, tokens/sec: {tokens_per_sec:,.1f}", flush=True)

medium

There are print statements on lines 336 and 352 that seem to be for debugging. These should be removed or replaced with logger calls for production code.
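For example, a drop-in replacement using the variables already in scope (a sketch, not the actual patch):

```python
import logging

logger = logging.getLogger(__name__)

logger.debug("MaxText forward-backward: batch [%d:%d], seq_len=%d", mb_start, mb_end, seq_len)
logger.debug(
    "Batch [%d:%d] forward-backward time: %.3f sec, tokens/sec: %.1f",
    mb_start, mb_end, took, tokens_per_sec,
)
```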

Comment on lines +326 to +338
```python
if num_layers is None:
    # Try to infer from shapes
    for tensor in lora_tensors.values():
        for dim in tensor.shape:
            if dim in [36, 48, 64, 80, 96]:  # Common layer counts
                num_layers = dim
                break
        if num_layers:
            break

if num_layers is None:
    raise ValueError("Could not determine num_layers from tensor shapes")
```


medium

The logic to determine num_layers by checking for common layer counts in tensor shapes is brittle. This might fail for models with a different number of layers. To make this function more robust, consider adding num_layers as an optional argument. If it's not provided, you can fall back to this inference logic.
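A sketch of the suggested signature change (the `_infer_num_layers` helper name is hypothetical):

```python
def convert_maxtext_lora_to_hf(
    lora_tensors: dict, num_layers: int | None = None
) -> dict:
    if num_layers is None:
        # Fall back to the shape-based heuristic shown above.
        num_layers = _infer_num_layers(lora_tensors)
    ...
```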

OhadRubin and others added 14 commits December 17, 2025 11:40
good catch

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Pin to 1edde4d1b1d562173d1753650b0234aa5c6a2fea for reproducible builds.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Try importlib.resources first for proper package resolution
- Raise FileNotFoundError with helpful message instead of hardcoded fallback
- Raise RuntimeError if MaxText is not available

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- process_forward_backward_batch now accepts PreparedModelPassBatch
- Add process_forward_batch stub with correct signature
- process_sample_batch now accepts PreparedSampleBatch

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
- Implement checkpoint creation logic to handle existing checkpoints by deleting them before adding new ones.
- Introduce model caching in TinkerEngine to manage least recently used models and optimize resource allocation.
- Update backend methods to manage optimizers internally, removing the need for external optimizer handling.
- Refactor model registration and unregistration processes across backends to streamline model lifecycle management.

This update improves the efficiency of model management and checkpoint handling in the Tinker framework.
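An illustrative LRU eviction scheme for the model cache described above; the capacity and stored value type are assumptions.

```python
from collections import OrderedDict


class ModelCache:
    """Keep at most `capacity` models, evicting the least recently used."""

    def __init__(self, capacity: int = 4):
        self.capacity = capacity
        self._models: OrderedDict[str, object] = OrderedDict()

    def get(self, model_id: str):
        self._models.move_to_end(model_id)  # mark as most recently used
        return self._models[model_id]

    def put(self, model_id: str, model) -> None:
        self._models[model_id] = model
        self._models.move_to_end(model_id)
        if len(self._models) > self.capacity:
            self._models.popitem(last=False)  # drop least recently used
```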
- Add a new function to reset LoRA weights, specifically zeroing out lora_b while preserving lora_a.
- Update the unregister_model method to utilize the new reset_adapter_weights function for improved weight management.
- Ensure model state is synchronized with updated LoRA parameters after unregistration.

This change enhances the handling of model weights in the MaxText backend, improving the efficiency of model lifecycle management.
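In JAX terms the reset might look like the sketch below; the flat parameter layout is an assumption. Zeroing `lora_b` suffices because the LoRA delta is `lora_b @ lora_a`, so the adapter contributes nothing while `lora_a` keeps its initialization.

```python
import jax.numpy as jnp


def reset_adapter_weights(params: dict) -> dict:
    """Zero lora_b (the adapter's output side), preserving lora_a."""
    return {
        name: jnp.zeros_like(p) if "lora_b" in name else p
        for name, p in params.items()
    }
```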
- Introduce a new TTL storage structure in ttl.md, including functions for setting and getting values with expiration, and a cleanup function for expired entries.
- Update the NativeBackend class to enable eager sharding during model creation, enhancing performance and resource management.
- Adjust loss calculation to prevent division by zero by using a maximum function on the loss mask.

These changes improve memory management and model handling in the Tinker framework.
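A minimal TTL store matching the `ttl.md` description (set/get with expiration plus a cleanup pass); the exact structure is an assumption.

```python
import time

_store: dict[str, tuple[float, object]] = {}


def ttl_set(key: str, value, ttl_seconds: float) -> None:
    _store[key] = (time.monotonic() + ttl_seconds, value)


def ttl_get(key: str, default=None):
    entry = _store.get(key)
    if entry is None or entry[0] < time.monotonic():
        _store.pop(key, None)  # expired or never set
        return default
    return entry[1]


def ttl_cleanup() -> None:
    now = time.monotonic()
    for key in [k for k, (expires, _) in _store.items() if expires < now]:
        del _store[key]
```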
- Introduce a mechanism to evict old sampler checkpoints in the Tinker framework, ensuring that only a specified number of checkpoints are retained per model.
- Update the MaxText backend to support dynamic loss function selection, allowing for more flexible loss calculations during training.
- Integrate the `tenacity` library for retrying failed HTTP requests in the ExternalInferenceClient, improving robustness against transient errors.

These changes enhance model management and error handling in the Tinker framework.
…ndling

- Introduce a new helper function `_is_retryable_error` to determine if exceptions are retryable based on connection errors and 5xx HTTP status codes.
- Update the retry decorator in the ExternalInferenceClient to utilize this new function, enhancing the robustness of HTTP request handling.

These changes improve the error recovery mechanism in the Tinker framework.
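Putting the two retry commits together, the wiring might look like this sketch; it assumes `httpx` as the HTTP client, and `_post_with_retry` is invented for illustration.

```python
import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


def _is_retryable_error(exc: BaseException) -> bool:
    """Retry only on connection errors and 5xx responses."""
    if isinstance(exc, (httpx.ConnectError, httpx.ReadTimeout)):
        return True
    if isinstance(exc, httpx.HTTPStatusError):
        return exc.response.status_code >= 500
    return False


@retry(
    retry=retry_if_exception(_is_retryable_error),
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=30),
)
def _post_with_retry(client: httpx.Client, url: str, payload: dict) -> httpx.Response:
    resp = client.post(url, json=payload)
    resp.raise_for_status()  # 4xx/5xx -> HTTPStatusError; retried only if 5xx
    return resp
```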