diff --git a/pufferlib/PMLL.py b/pufferlib/PMLL.py
new file mode 100644
index 000000000..2fa7ec7cf
--- /dev/null
+++ b/pufferlib/PMLL.py
@@ -0,0 +1,238 @@
# PMLL.py: Persistent Memory Logic Loop Implementation
# This Python module implements the PMLL architecture for memory-efficient LLM inference.
# It integrates with C extensions for high-performance operations, including SIMD-optimized
# routines (effectively intermixed Assembly via intrinsics).
# Requires: libpmlL_backend.so (compiled C library with SIMD intrinsics)
# Author: Based on PMLL Architecture Paper (2025)
# License: MIT

import ctypes
from ctypes import POINTER, c_int, c_float, c_void_p
from collections import deque
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch  # Assuming PyTorch for Transformer integration

# Load the C backend library (contains C code with SIMD/Assembly intrinsics)
try:
    lib = ctypes.CDLL("./libpmlL_backend.so")  # Adjust path as needed
except OSError as e:
    raise ImportError("libpmlL_backend.so not found. Compile the C extensions first.") from e

# C type definitions (mirroring the C structs)
class MemoryPool(ctypes.Structure):
    _fields_ = [
        ("size", c_int),
        ("data", POINTER(c_void_p)),  # Pointer to array of entries
        ("utilization", c_float),
    ]

class PromiseQueue(ctypes.Structure):
    _fields_ = [
        ("capacity", c_int),
        ("head", c_int),
        ("tail", c_int),
        ("promises", POINTER(c_void_p)),  # Array of promise pointers
    ]

class Request(ctypes.Structure):
    _fields_ = [
        ("type", c_int),  # 0: READ, 1: WRITE
        ("id", c_int),
        ("data", c_void_p),
    ]

# C function signatures
lib.phi.argtypes = [c_int, c_int]
lib.phi.restype = c_int

lib.process_promise_queue.argtypes = [POINTER(PromiseQueue), POINTER(MemoryPool)]
lib.process_promise_queue.restype = POINTER(MemoryPool)

lib.vectorized_attention.argtypes = [POINTER(c_float), POINTER(c_float), POINTER(c_float), c_int]
lib.vectorized_attention.restype = None  # In-place; results returned via the pointers

lib.trigger_compression.argtypes = [POINTER(MemoryPool), c_float]
lib.trigger_compression.restype = None

# Python wrapper for phi (collision-free hashing)
def phi(id: int, n: int) -> int:
    """Collision-free slot assignment using modular arithmetic."""
    return lib.phi(id, n)
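# A pure-Python reference for phi, useful for testing without the C backend.
# The exact C implementation is not part of this patch; plain modular hashing
# is an assumption consistent with the docstring above, and it is collision-free
# only while the ids resident in the pool are distinct modulo n.
def phi_py(id: int, n: int) -> int:
    """Reference slot assignment: maps id into [0, n)."""
    return id % n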
# Memory controller class (Python orchestration with C calls)
class MemoryController:
    def __init__(self, pool_size: int):
        self.pool_size = pool_size
        self.pool = [None] * pool_size  # Python view; actual data lives in the C pool
        self.promise_queue = deque()  # High-level queue; syncs with C
        # Initialize C structures
        self.c_pool = MemoryPool(pool_size, None, 0.0)
        self.c_queue = PromiseQueue(pool_size, 0, 0, None)

    def process_request(self, request: Dict[str, Any]) -> Optional[Any]:
        """Process read/write requests, delegating to C for performance."""
        req_type = request["type"]
        req_id = request["id"]

        if req_type == "read":
            slot = phi(req_id, self.pool_size)
            # Call C for an optimized read (uses SIMD for batch reads where applicable)
            return self._c_read(self.c_pool, slot)
        elif req_type == "write":
            promise = self._create_promise(request)
            self.promise_queue.append(promise)
            # Enqueue in the C queue for atomic processing
            self._c_enqueue(self.c_queue, promise)
        return None

    def retrieve_relevant(self, Q: torch.Tensor) -> Optional[List[Any]]:
        """Retrieve persistent entries relevant to Q (delegates to the module-level helper)."""
        return retrieve_relevant(Q, self.pool)

    def _c_read(self, pool: POINTER(MemoryPool), slot: int) -> Any:
        # Placeholder: a full implementation would extract from pool.data[slot]
        return self.pool[slot]

    def _create_promise(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Create a promise with TTL and importance score."""
        return {
            "id": request["id"],
            "data": request["data"],
            "ttl": 3600,  # Example TTL in seconds
            "importance": np.random.rand(),  # Placeholder; use actual scoring
        }

    def _c_enqueue(self, queue: POINTER(PromiseQueue), promise: Dict[str, Any]):
        # Serialize the promise to C and enqueue (simplified)
        pass  # A full implementation would marshal the data via ctypes

    def process_promise_queue(self):
        """Process the promise queue using the C backend."""
        lib.process_promise_queue(self.c_queue, self.c_pool)
        # Sync the Python-side queue (TTL bookkeeping is simplified here)
        while self.promise_queue:
            promise = self.promise_queue.popleft()
            if promise["ttl"] > 0:
                slot = phi(promise["id"], self.pool_size)
                self.pool[slot] = promise["data"]

    def trigger_compression(self, rho: float = 0.1):
        """Trigger recursive compression via the C routine."""
        lib.trigger_compression(self.c_pool, c_float(rho))
        # Python-side post-processing, mirroring the C pass
        self.pool = self._recursive_compress(self.pool, rho)

    def _recursive_compress(self, pool: List[Any], rho: float) -> List[Any]:
        """Python fallback for compression (C is the primary path)."""
        if not pool:
            return pool
        scores = [np.random.rand() for _ in pool]  # Placeholder importance scores
        threshold = np.quantile(scores, rho)
        compressed = []
        for entry, score in zip(pool, scores):
            if score >= threshold:
                q = 8 if score > 0.8 else 4  # Bits for quantization
                compressed.append(self._quantize(entry, q))
        # Pad back to the original length so phi-based slot indexing stays valid
        compressed.extend([None] * (len(pool) - len(compressed)))
        return compressed

    def _quantize(self, entry: Any, bits: int) -> Any:
        """Simple quantization placeholder."""
        if isinstance(entry, np.ndarray) and entry.max() > 0:
            return (entry * (2**bits - 1) / entry.max()).astype(np.int32)
        return entry
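# Usage sketch for MemoryController in isolation (not part of the C API).
# The id/data values are arbitrary illustrative choices; reads return None
# until the promise queue has been drained into the pool.
def _memory_controller_demo(pool_size: int = 8) -> None:
    mc = MemoryController(pool_size)
    mc.process_request({"type": "write", "id": 3, "data": "hello"})
    mc.process_promise_queue()  # Drains promises into pool slots via phi
    print(mc.process_request({"type": "read", "id": 3}))  # -> "hello"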
# Custom PML attention mechanism (hybrid local + persistent)
def pml_attention(Q: torch.Tensor, K_local: torch.Tensor, V_local: torch.Tensor,
                  memory_controller: MemoryController) -> torch.Tensor:
    """Hybrid attention: local plus persistent-memory retrieval (unscaled dot-product for brevity)."""
    # Local attention over batch-first (batch, seq, d) tensors
    A_local = torch.softmax(Q @ K_local.transpose(-2, -1), dim=-1) @ V_local

    # Retrieve relevant persistent memory via the controller
    M_relevant = memory_controller.retrieve_relevant(Q)
    if M_relevant is None:
        return A_local

    # Extract K_p, V_p from persistent entries (assumed to be (k, v) pairs
    # shaped like K_local/V_local)
    K_p, V_p = extract_keys_values(M_relevant)

    # Persistent attention (vectorized via C when the batch is large)
    if K_p.shape[0] > 32:  # Threshold for the C call
        # The C routine expects contiguous float32 buffers
        Qc, Kc, Vc = Q.contiguous(), K_p.contiguous(), V_p.contiguous()
        lib.vectorized_attention(ctypes.cast(Qc.data_ptr(), POINTER(c_float)),
                                 ctypes.cast(Kc.data_ptr(), POINTER(c_float)),
                                 ctypes.cast(Vc.data_ptr(), POINTER(c_float)),
                                 c_int(Q.shape[-1]))
    A_persistent = torch.softmax(Q @ K_p.transpose(-2, -1), dim=-1) @ V_p

    # Blend local and persistent attention
    alpha = compute_alpha(Q, K_local, K_p)
    return alpha * A_local + (1 - alpha) * A_persistent

def extract_keys_values(M_relevant: List[Any]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract keys and values from persistent memory entries."""
    # Placeholder: assumes M_relevant is a list of (k, v) tensor pairs
    Ks = torch.stack([entry[0] for entry in M_relevant])
    Vs = torch.stack([entry[1] for entry in M_relevant])
    return Ks, Vs

def compute_alpha(Q: torch.Tensor, K_local: torch.Tensor, K_p: torch.Tensor) -> torch.Tensor:
    """Compute the blending factor (e.g., based on similarity)."""
    sim_local = torch.norm(Q - K_local.mean(0), dim=-1)
    sim_p = torch.norm(Q - K_p.mean(0), dim=-1)
    return torch.sigmoid(sim_local - sim_p).unsqueeze(-1)

# Retrieval helper for the controller
def retrieve_relevant(Q: torch.Tensor, pool: List[Any]) -> Optional[List[Any]]:
    """Retrieve relevant entries from the pool based on similarity to Q."""
    # Simplified cosine similarity; in production, use FAISS or a C-optimized search
    relevant = [entry for entry in pool
                if entry is not None and cosine_sim(Q, entry[0]) > 0.5]
    return relevant if relevant else None

def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float:
    # Assumes a and b have the same number of elements
    return torch.dot(a.flatten(), b.flatten()) / (torch.norm(a) * torch.norm(b))

# Example usage / integration with a Transformer
class PMLLTransformer(torch.nn.Module):
    def __init__(self, d_model: int, nhead: int, num_layers: int, pool_size: int = 1024):
        super().__init__()
        self.transformer = torch.nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        # Simplified joint Q/K/V projection for the PML attention path; a full
        # integration would reuse the transformer's own in-projections and KV cache.
        self.qkv_proj = torch.nn.Linear(d_model, 3 * d_model)
        self.memory_controller = MemoryController(pool_size)
        self.d_model = d_model

    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                kv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
        # src is batch-first: (batch, seq, d_model). Derive a simplified
        # (Q, K_local, V_local) triple when no cache is supplied.
        if kv_cache is None:
            kv_cache = self.qkv_proj(src)
        Q, K_local, V_local = kv_cache.split(self.d_model, dim=-1)

        # Apply PML attention
        attn_output = pml_attention(Q, K_local, V_local, self.memory_controller)

        # Update the cache and persistent memory (id masked into C int range)
        new_kv = (Q, attn_output)  # Placeholder
        self.memory_controller.process_request(
            {"type": "write", "id": hash(str(Q)) & 0x7FFFFFFF, "data": new_kv})
        self.memory_controller.process_promise_queue()

        if self.memory_controller.c_pool.utilization > 0.8:  # C-exposed utilization
            self.memory_controller.trigger_compression()

        return attn_output

# Main entry point for testing
if __name__ == "__main__":
    # Example instantiation
    model = PMLLTransformer(d_model=512, nhead=8, num_layers=6)
    src = torch.rand(32, 10, 512)  # (batch=32, seq=10, d_model=512)
    output = model(src, src)
    print(f"PMLL Output shape: {output.shape}")
    print("PMLL initialized successfully. C/Assembly intermix via SIMD intrinsics.")
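    # Additional sketches: exercise the standalone controller demo defined above,
    # and sanity-check the C phi against the pure-Python reference (which assumes
    # plain modular hashing).
    _memory_controller_demo()
    print(f"phi(7, 1024) = {phi(7, 1024)} (reference: {phi_py(7, 1024)})")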