Skip to content

Commit e63d001

Browse files
committed
Memory profiling
1 parent 108e831 commit e63d001

File tree

55 files changed

+142
-9
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+142
-9
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ tests/integration/core/__pycache__
66
tests/integration/core/__pycache__
77
tests/unit/core/__pycache__
88
.env
9+
*.pyc
10+
memory_profile.txt
-141 Bytes
Binary file not shown.
-158 Bytes
Binary file not shown.
Binary file not shown.
-158 Bytes
Binary file not shown.
Binary file not shown.
-257 Bytes
Binary file not shown.
-31 KB
Binary file not shown.

grizabella/mcp/server.py

Lines changed: 139 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import os
1515
import signal
1616
import sys
17+
import threading
18+
import time
1719
import uuid
1820
import tracemalloc
1921
from datetime import datetime, timezone
@@ -39,6 +41,15 @@
3941
)
4042
from grizabella.core.query_models import ComplexQuery, EmbeddingVector, QueryResult
4143

44+
# Set up logging configuration
45+
logging.basicConfig(
46+
level=logging.INFO,
47+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
48+
handlers=[
49+
logging.StreamHandler(sys.stdout)
50+
]
51+
)
52+
4253
logger = logging.getLogger(__name__)
4354

4455
# --- Configuration ---
@@ -795,14 +806,70 @@ async def mcp_get_embedding_vector_for_text(args: GetEmbeddingVectorForTextArgs)
795806
# `python -m fastmcp grizabella.mcp.server:app`
796807
# or similar, depending on FastMCP's conventions.
797808

809+
# Global variables for memory profiling
810+
snapshot1 = None
811+
args = None
812+
813+
def print_memory_stats(label="Memory Stats"):
814+
"""Print current memory statistics if profiling is enabled."""
815+
global snapshot1, args
816+
logger.info(f"print_memory_stats called with label: {label}")
817+
logger.info(f"args: {args}, profile_mem: {getattr(args, 'profile_mem', 'N/A') if args else 'N/A'}, snapshot1: {snapshot1 is not None}")
818+
if args and args.profile_mem and snapshot1:
819+
try:
820+
current_snapshot = tracemalloc.take_snapshot()
821+
top_stats = current_snapshot.compare_to(snapshot1, 'lineno')
822+
logger.info("[Current Top 10 memory consumers]")
823+
stats_output = []
824+
stats_output.append("[Current Top 10 memory consumers]")
825+
for stat in top_stats[:10]:
826+
logger.info(stat)
827+
stats_output.append(str(stat))
828+
829+
# Save to file if specified
830+
# Save to file if specified
831+
if args.profile_file:
832+
try:
833+
with open(args.profile_file, 'a') as f:
834+
f.write(f"\n--- {label} at {datetime.now()} ---\n")
835+
f.write('\n'.join(stats_output) + '\n')
836+
f.flush() # Ensure data is written immediately
837+
os.fsync(f.fileno()) # Force OS to write to disk
838+
except Exception as e:
839+
logger.error(f"Error saving memory stats to file: {e}")
840+
except Exception as e:
841+
logger.error(f"Error printing memory stats: {e}")
842+
798843
def shutdown_handler(signum, frame):
799844
"""Handle shutdown signals gracefully."""
800845
print(f"Received signal {signum}, shutting down...")
846+
logger.info(f"Received signal {signum}, shutting down...")
847+
# Print final memory stats before exiting
848+
print_memory_stats()
849+
# Ensure memory profile file is flushed
850+
global args
851+
logger.info(f"Shutdown handler args: {args}")
852+
if args and args.profile_mem and args.profile_file:
853+
try:
854+
# Try to flush any buffered writes
855+
logger.info(f"Writing shutdown signal to profile file: {args.profile_file}")
856+
with open(args.profile_file, 'a') as f:
857+
f.write(f"\n--- Shutdown Signal {signum} received at {datetime.now()} ---\n")
858+
f.flush()
859+
os.fsync(f.fileno())
860+
logger.info("Shutdown signal written to profile file successfully")
861+
# Small delay to ensure file operations complete
862+
time.sleep(0.1)
863+
except Exception as e:
864+
print(f"Error writing shutdown signal to profile file: {e}")
865+
logger.error(f"Error writing shutdown signal to profile file: {e}")
801866
# Perform any cleanup here if needed
802867
sys.exit(0)
803868

804869
def main():
805870
"""Initializes client and runs the FastMCP application."""
871+
global snapshot1, args
872+
806873
# Start memory tracing
807874
tracemalloc.start()
808875
snapshot1 = tracemalloc.take_snapshot()
@@ -814,33 +881,97 @@ def main():
814881
parser = argparse.ArgumentParser(description="Grizabella MCP Server")
815882
parser.add_argument("--db-path", help="Path to the Grizabella database.")
816883
parser.add_argument("--profile-mem", action="store_true", help="Enable memory profiling")
884+
parser.add_argument("--profile-interval", type=int, default=0, help="Interval in seconds to print memory stats (0 = disabled)")
885+
parser.add_argument("--profile-file", type=str, default="memory_profile.txt", help="File to save memory profiling data")
817886
args = parser.parse_args()
818887

819888
global grizabella_client_instance
820889
db_path = get_grizabella_db_path(args.db_path)
821890

822891
if args.profile_mem:
823892
logger.info("Memory profiling enabled")
824-
logger.info(f"Initial memory snapshot: {snapshot1.statistics('lineno')[:10]}")
825-
893+
initial_stats = snapshot1.statistics('lineno')[:10]
894+
logger.info(f"Initial memory snapshot: {initial_stats}")
895+
logger.info(f"Initial memory stats count: {len(initial_stats)}")
896+
897+
# Save initial snapshot to file
898+
if args.profile_file:
899+
try:
900+
with open(args.profile_file, 'w') as f:
901+
f.write(f"Initial memory snapshot at {datetime.now()}\n")
902+
for stat in snapshot1.statistics('lineno')[:10]:
903+
f.write(f"{stat}\n")
904+
logger.info(f"Initial memory snapshot saved to {args.profile_file}")
905+
except Exception as e:
906+
logger.error(f"Error saving initial memory snapshot: {e}")
907+
908+
# Start periodic memory reporting if interval is set (before app.run)
909+
memory_thread = None
910+
if args.profile_mem and args.profile_interval > 0:
911+
def periodic_memory_report():
912+
logger.info("Periodic memory reporting thread started")
913+
iteration = 0
914+
while True:
915+
try:
916+
iteration += 1
917+
logger.info(f"Periodic memory report #{iteration} triggered, sleeping for {args.profile_interval} seconds") # type: ignore
918+
time.sleep(args.profile_interval) # type: ignore
919+
logger.info(f"Periodic memory report #{iteration} running")
920+
print_memory_stats(f"Periodic Report #{iteration}")
921+
logger.info(f"Periodic memory report #{iteration} completed")
922+
except Exception as e:
923+
logger.error(f"Error in periodic memory reporting thread: {e}")
924+
# Continue running even if there's an error
925+
time.sleep(5) # Wait a bit before retrying
926+
927+
memory_thread = threading.Thread(target=periodic_memory_report, daemon=True)
928+
memory_thread.start()
929+
logger.info(f"Periodic memory reporting enabled every {args.profile_interval} seconds")
930+
826931
try:
827932
with Grizabella(db_name_or_path=db_path, create_if_not_exists=True) as gb:
828933
grizabella_client_instance = gb
934+
# Print memory stats before starting the server
935+
print_memory_stats("Before starting FastMCP server")
829936
app.run(show_banner=False)
937+
# Print memory stats after server stops (if it ever stops normally)
938+
print_memory_stats("After FastMCP server stopped")
830939
except Exception as e:
831940
print(f"Server error: {e}")
941+
# Print memory stats on error
942+
print_memory_stats()
832943
sys.exit(1)
833944
finally:
834945
# Ensure clean termination
835946
grizabella_client_instance = None
836947
print("Server terminated cleanly")
837948

838-
if args.profile_mem:
839-
snapshot2 = tracemalloc.take_snapshot()
840-
top_stats = snapshot2.compare_to(snapshot1, 'lineno')
841-
logger.info("[Top 10 memory differences]")
842-
for stat in top_stats[:10]:
843-
logger.info(stat)
949+
if args and args.profile_mem and snapshot1:
950+
try:
951+
logger.info("Taking final memory snapshot")
952+
snapshot2 = tracemalloc.take_snapshot()
953+
top_stats = snapshot2.compare_to(snapshot1, 'lineno')
954+
logger.info("[Top 10 memory differences]")
955+
stats_output = []
956+
stats_output.append("[Top 10 memory differences]")
957+
for stat in top_stats[:10]:
958+
logger.info(stat)
959+
stats_output.append(str(stat))
960+
961+
# Save to file if specified
962+
if args.profile_file:
963+
try:
964+
with open(args.profile_file, 'a') as f:
965+
f.write(f"\n--- Final Memory Differences at {datetime.now()} ---\n")
966+
f.write('\n'.join(stats_output) + '\n')
967+
f.write(f"Total memory allocated: {sum(stat.size for stat in top_stats)} bytes\n")
968+
f.flush()
969+
os.fsync(f.fileno())
970+
logger.info("Final memory differences written to file")
971+
except Exception as e:
972+
logger.error(f"Error writing final memory differences to file: {e}")
973+
except Exception as e:
974+
logger.error(f"Error in final memory profiling: {e}")
844975

845976
sys.exit(0)
846977

-144 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)