@@ -14,7 +14,7 @@ class NodeExecutionStatsSummary:
     node_type: str
     num_calls: int
     time_used_seconds: float
-    peak_vram_gb: float
+    delta_vram_gb: float


 @dataclass
@@ -58,10 +58,10 @@ class InvocationStatsSummary:
     def __str__(self) -> str:
         _str = ""
         _str = f"Graph stats: {self.graph_stats.graph_execution_state_id}\n"
-        _str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}\n"
+        _str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Change':+>10}\n"

         for summary in self.node_stats:
-            _str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G\n"
+            _str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.delta_vram_gb:+10.3f}G\n"

         _str += f"TOTAL GRAPH EXECUTION TIME: {self.graph_stats.execution_time_seconds:7.3f}s\n"

@@ -100,7 +100,7 @@ class NodeExecutionStats:
     start_ram_gb: float  # GB
     end_ram_gb: float  # GB

-    peak_vram_gb: float  # GB
+    delta_vram_gb: float  # GB

     def total_time(self) -> float:
         return self.end_time - self.start_time
@@ -174,9 +174,9 @@ def get_node_stats_summaries(self) -> list[NodeExecutionStatsSummary]:
         for node_type, node_type_stats_list in node_stats_by_type.items():
             num_calls = len(node_type_stats_list)
             time_used = sum([n.total_time() for n in node_type_stats_list])
-            peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
+            delta_vram = max([n.delta_vram_gb for n in node_type_stats_list])
             summary = NodeExecutionStatsSummary(
-                node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, peak_vram_gb=peak_vram
+                node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, delta_vram_gb=delta_vram
             )
             summaries.append(summary)
@@ -52,8 +52,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st
         # Record state before the invocation.
         start_time = time.time()
         start_ram = psutil.Process().memory_info().rss
-        if torch.cuda.is_available():
-            torch.cuda.reset_peak_memory_stats()
+
+        # Remember current VRAM usage
+        vram_in_use = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0.0

         assert services.model_manager.load is not None
         services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
@@ -62,14 +63,16 @@
             # Let the invocation run.
             yield None
         finally:
-            # Record state after the invocation.
+            # Record delta VRAM
+            delta_vram_gb = ((torch.cuda.memory_allocated() - vram_in_use) / GB) if torch.cuda.is_available() else 0.0
+
             node_stats = NodeExecutionStats(
                 invocation_type=invocation.get_type(),
                 start_time=start_time,
                 end_time=time.time(),
                 start_ram_gb=start_ram / GB,
                 end_ram_gb=psutil.Process().memory_info().rss / GB,
-                peak_vram_gb=torch.cuda.max_memory_allocated() / GB if torch.cuda.is_available() else 0.0,
+                delta_vram_gb=delta_vram_gb,
             )
             self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
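The measurement pattern introduced above can be exercised in isolation; a rough sketch, assuming a CUDA device and the same GB constant used by the stats service:

    import torch

    GB = 2**30

    if torch.cuda.is_available():
        vram_before = torch.cuda.memory_allocated()
        work = torch.zeros((1024, 1024), device="cuda")  # stand-in for the invocation's work
        delta_vram_gb = (torch.cuda.memory_allocated() - vram_before) / GB
        # The delta can be negative if the node frees more VRAM than it allocates.
        print(f"VRAM change: {delta_vram_gb:+.3f}G")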

@@ -81,6 +84,8 @@ def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
         node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
         model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
+        # Note: We use memory_allocated() here (not memory_reserved()) because we want to show
+        # the current actively-used VRAM, not the total reserved memory including PyTorch's cache.
         vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None

         return InvocationStatsSummary(
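A quick sketch contrasting the two counters referenced in the new comment (assumes a CUDA device): memory_allocated() reports tensors the program currently holds, while memory_reserved() also includes memory retained by PyTorch's caching allocator:

    import torch

    if torch.cuda.is_available():
        allocated_gb = torch.cuda.memory_allocated() / 2**30
        reserved_gb = torch.cuda.memory_reserved() / 2**30
        # reserved_gb >= allocated_gb; the gap is the allocator's cache.
        print(f"allocated={allocated_gb:.3f}G reserved={reserved_gb:.3f}G")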
@@ -240,6 +240,9 @@ def stats(self) -> Optional[CacheStats]:
     def stats(self, stats: CacheStats) -> None:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats
+        # Populate the cache size in the stats object when it's set
+        if self._stats is not None:
+            self._stats.cache_size = self._ram_cache_size_bytes

     def _record_activity(self) -> None:
         """Record model activity and reset the timeout timer if configured.
invokeai/backend/util/vae_working_memory.py (2 changes: 0 additions & 2 deletions)
@@ -47,8 +47,6 @@ def estimate_vae_working_memory_sd15_sdxl(
         # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
         working_memory += 250 * 2**20

-    print(f"estimate_vae_working_memory_sd15_sdxl: {int(working_memory)}")
-
     return int(working_memory)

