238 changes: 238 additions & 0 deletions pufferlib/PMLL.py
@@ -0,0 +1,238 @@
# PMLL.py: Persistent Memory Logic Loop Implementation
# This Python module implements the PMLL architecture for memory-efficient LLM inference.
# It integrates with C extensions for high-performance operations, including SIMD-optimized
# routines (assembly-level vector code emitted from compiler SIMD intrinsics).
# Requires: libpmlL_backend.so (compiled C library with SIMD intrinsics)
# Author: Based on PMLL Architecture Paper (2025)
# License: MIT

import ctypes
from ctypes import POINTER, c_int, c_float, c_void_p
from collections import deque
import numpy as np
import torch # Assuming PyTorch for Transformer integration
from typing import List, Dict, Any, Optional

# Load the C backend library (contains C code with SIMD/Assembly intrinsics)
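# Illustrative build step, assuming a single-file C source named pmll_backend.c
# (actual file names and flags depend on the repo's C code):
#   gcc -O3 -mavx2 -shared -fPIC pmll_backend.c -o libpmlL_backend.so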
try:
lib = ctypes.CDLL("./libpmlL_backend.so") # Adjust path as needed
except OSError as exc:
    raise ImportError("libpmlL_backend.so not found. Compile the C extensions first.") from exc

# C Type Definitions (mirroring C structs)
class MemoryPool(ctypes.Structure):
_fields_ = [
("size", c_int),
("data", POINTER(c_void_p)), # Pointer to array of entries
("utilization", c_float)
]

class PromiseQueue(ctypes.Structure):
_fields_ = [
("capacity", c_int),
("head", c_int),
("tail", c_int),
("promises", POINTER(c_void_p)) # Array of promise pointers
]

class Request(ctypes.Structure):
_fields_ = [
("type", c_int), # 0: READ, 1: WRITE
("id", c_int),
("data", c_void_p)
]

# C Function Signatures
lib.phi.argtypes = [c_int, c_int]
lib.phi.restype = c_int

lib.process_promise_queue.argtypes = [POINTER(PromiseQueue), POINTER(MemoryPool)]
lib.process_promise_queue.restype = POINTER(MemoryPool)

lib.vectorized_attention.argtypes = [POINTER(c_float), POINTER(c_float), POINTER(c_float), c_int]
lib.vectorized_attention.restype = None # In-place or returns via pointers

lib.trigger_compression.argtypes = [POINTER(MemoryPool), c_float]
lib.trigger_compression.restype = None
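
# For reference, the prototypes implied by the argtypes/restype declarations above
# (the actual headers shipped with libpmlL_backend.so are not shown in this diff):
#
#   int         phi(int id, int n);
#   MemoryPool* process_promise_queue(PromiseQueue* queue, MemoryPool* pool);
#   void        vectorized_attention(float* Q, float* K, float* V, int d);
#   void        trigger_compression(MemoryPool* pool, float rho);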

# Python Wrapper for phi (collision-free hashing)
def phi(id: int, n: int) -> int:
"""Collision-free slot assignment using modular arithmetic."""
return lib.phi(id, n)
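
# Pure-Python sketch of the slot assignment, useful when the C backend is not
# available. The real phi() lives in C; plain modular hashing is an assumption
# here, not necessarily the library's exact scheme.
def phi_py(id: int, n: int) -> int:
    """Fallback slot assignment: map an id into [0, n) via modular arithmetic."""
    return id % n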

# Memory Controller Class (Python orchestration with C calls)
class MemoryController:
def __init__(self, pool_size: int):
self.pool_size = pool_size
self.pool = [None] * pool_size # Python view; actual data in C pool
self.promise_queue = deque() # High-level queue; syncs with C
# Initialize C structures
self.c_pool = MemoryPool(pool_size, None, 0.0)
self.c_queue = PromiseQueue(pool_size, 0, 0, None)

def process_request(self, request: Dict[str, Any]) -> Optional[Any]:
"""Process read/write requests, delegating to C for performance."""
req_type = request["type"]
req_id = request["id"]

if req_type == "read":
slot = phi(req_id, self.pool_size)
# Call C for optimized read (uses SIMD for batch reads if applicable)
result = self._c_read(self.c_pool, slot)
return result
elif req_type == "write":
promise = self._create_promise(request)
self.promise_queue.append(promise)
# Enqueue in C queue for atomic processing
self._c_enqueue(self.c_queue, promise)
return None

    def _c_read(self, pool: MemoryPool, slot: int) -> Any:
        # Placeholder: a full implementation would read pool.data[slot] on the C side
        return self.pool[slot]

def _create_promise(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Create a promise with TTL and importance score."""
return {
"id": request["id"],
"data": request["data"],
"ttl": 3600, # Example TTL in seconds
"importance": np.random.rand() # Placeholder; use actual scoring
}

def _c_enqueue(self, queue: POINTER(PromiseQueue), promise: Dict[str, Any]):
# Serialize promise to C and enqueue (simplified)
pass # Full impl would use ctypes to pass data

def process_promise_queue(self):
"""Process the promise queue using C backend."""
lib.process_promise_queue(self.c_queue, self.c_pool)
# Sync Python queue if needed
while self.promise_queue:
promise = self.promise_queue.popleft()
if promise["ttl"] > 0:
slot = phi(promise["id"], self.pool_size)
self.pool[slot] = promise["data"]

def trigger_compression(self, rho: float = 0.1):
"""Trigger recursive compression via C routine."""
lib.trigger_compression(self.c_pool, c_float(rho))
# Python-side post-processing if needed
self.pool = self._recursive_compress(self.pool, rho)

def _recursive_compress(self, pool: List[Any], rho: float) -> List[Any]:
"""Python fallback for compression (C is primary)."""
if not pool:
return pool
scores = [np.random.rand() for _ in pool] # Placeholder importance scores
threshold = np.quantile(scores, rho)
        compressed = []
        for entry, score in zip(pool, scores):
            if entry is None or score < threshold:
                compressed.append(None)  # Drop low-importance/empty entries but keep slot alignment for phi()
            else:
                q = 8 if score > 0.8 else 4  # Bits for quantization
                compressed.append(self._quantize(entry, q))
        return compressed
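
    # Worked example of the policy above (illustrative): with rho = 0.1, entries
    # whose importance falls below the 10th percentile are cleared from their
    # slots; survivors scoring above 0.8 keep 8-bit precision, the rest 4-bit.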

def _quantize(self, entry: Any, bits: int) -> Any:
"""Simple quantization placeholder."""
if isinstance(entry, np.ndarray):
return (entry * (2**bits - 1) / entry.max()).astype(np.int32)
return entry
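
# Illustrative read/write flow through the controller (like the rest of this
# module, it assumes the compiled C backend is loadable):
#
#   mc = MemoryController(pool_size=1024)
#   mc.process_request({"type": "write", "id": 7, "data": some_tensor})  # deferred as a promise
#   mc.process_promise_queue()                                           # promise lands in its phi() slot
#   cached = mc.process_request({"type": "read", "id": 7})               # immediate read from the pool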

# Custom PML Attention Mechanism (Hybrid local + persistent)
def pml_attention(Q: torch.Tensor, K_local: torch.Tensor, V_local: torch.Tensor,
memory_controller: MemoryController) -> torch.Tensor:
"""Hybrid attention: local + persistent memory retrieval."""
    # Local attention (scaled dot-product over the last two dims)
    d = Q.shape[-1]
    A_local = torch.softmax(Q @ K_local.transpose(-2, -1) / d ** 0.5, dim=-1) @ V_local

    # Retrieve relevant persistent memory from the controller's pool
    M_relevant = retrieve_relevant(Q, memory_controller.pool)
    if M_relevant is None:
        return A_local

# Extract K_p, V_p from persistent (use C for extraction if batched)
K_p, V_p = extract_keys_values(M_relevant)

    # Persistent attention (vectorized via C when the retrieved memory is large)
    if K_p.shape[0] > 32:  # Threshold for C call
        # Pass raw pointers to the SIMD routine; assumes contiguous float32 tensors
        q_ptr = Q.data_ptr()
        k_ptr = K_p.data_ptr()
        v_ptr = V_p.data_ptr()
        lib.vectorized_attention(ctypes.cast(q_ptr, POINTER(c_float)),
                                 ctypes.cast(k_ptr, POINTER(c_float)),
                                 ctypes.cast(v_ptr, POINTER(c_float)),
                                 c_int(d))
    # Reference computation in PyTorch (the C path accelerates the large-memory case)
    A_persistent = torch.softmax(Q @ K_p.transpose(-2, -1) / d ** 0.5, dim=-1) @ V_p

# Blending factor alpha (placeholder)
alpha = compute_alpha(Q, K_local, K_p)

return alpha * A_local + (1 - alpha) * A_persistent

def extract_keys_values(M_relevant: List[Any]) -> tuple[torch.Tensor, torch.Tensor]:
"""Extract keys and values from persistent memory entries."""
# Placeholder: assume M_relevant is list of (k,v) pairs
Ks = torch.stack([entry[0] for entry in M_relevant])
Vs = torch.stack([entry[1] for entry in M_relevant])
return Ks, Vs
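
# Assumed entry layout (matching PMLLTransformer.forward below): each non-None
# pool slot holds a (key_tensor, value_tensor) pair, so stacking yields K_p and
# V_p with a leading dimension equal to the number of retrieved entries.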

def compute_alpha(Q: torch.Tensor, K_local: torch.Tensor, K_p: torch.Tensor) -> torch.Tensor:
    """Compute blending factor from distances to local vs. persistent context (closer local context => larger alpha)."""
    dist_local = torch.norm(Q - K_local.mean(0), dim=-1)
    dist_p = torch.norm(Q - K_p.mean(0), dim=-1)
    return torch.sigmoid(dist_p - dist_local).unsqueeze(-1)

# Retrieval helper for controller
def retrieve_relevant(Q: torch.Tensor, pool: List[Any]) -> Optional[List[Any]]:
"""Retrieve relevant entries from pool based on Q similarity."""
# Simplified cosine similarity; in prod, use FAISS or C-optimized search
relevant = [entry for entry in pool if entry is not None and cosine_sim(Q, entry[0]) > 0.5]
return relevant if relevant else None

def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity between two flattened tensors (assumes equal element counts)."""
    return float(torch.dot(a.flatten(), b.flatten()) / (torch.norm(a) * torch.norm(b)))

# Example Usage / Integration with Transformer
class PMLLTransformer(torch.nn.Module):
def __init__(self, d_model: int, nhead: int, num_layers: int, pool_size: int = 1024):
super().__init__()
        self.transformer = torch.nn.Transformer(d_model=d_model, nhead=nhead,
                                                num_encoder_layers=num_layers,
                                                num_decoder_layers=num_layers)
self.memory_controller = MemoryController(pool_size)
self.d_model = d_model

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, kv_cache: Optional[torch.Tensor] = None):
        # kv_cache holds the local K/V concatenated along the last dim;
        # seed it from the source embeddings on the first call
        if kv_cache is None:
            kv_cache = torch.cat([src, src], dim=-1)

        # Simplified query projection: reuse the query rows of the first layer's in-projection
        W_q = self.transformer.encoder.layers[0].self_attn.in_proj_weight[:self.d_model]
        Q = src @ W_q.T
        K_local, V_local = kv_cache.split(self.d_model, dim=-1)

# Apply PML Attention
attn_output = pml_attention(Q, K_local, V_local, self.memory_controller)

# Update cache and persistent memory
        new_kv = (Q, attn_output)  # Placeholder cache entry: (key, value)
        req_id = abs(hash(str(Q))) % self.memory_controller.pool_size  # Keep the id within c_int range
        self.memory_controller.process_request({"type": "write", "id": req_id, "data": new_kv})
self.memory_controller.process_promise_queue()

if self.memory_controller.c_pool.utilization > 0.8: # C-exposed utilization
self.memory_controller.trigger_compression()

return attn_output

# Main entry point for testing
if __name__ == "__main__":
# Example instantiation
model = PMLLTransformer(d_model=512, nhead=8, num_layers=6)
    src = torch.rand(10, 32, 512)  # (seq=10, batch=32, d_model=512)
output = model(src, src)
print(f"PMLL Output shape: {output.shape}")
print("PMLL initialized successfully. C/Assembly intermix via SIMD intrinsics.")