Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions point_transformer_v3/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*.nsys-rep
../panoptic_segmentation/ptv3
/tests/fvdb-test-data
fvdb-test-data
!requirements.txt
/tests
/data
/__pycache__/
12 changes: 8 additions & 4 deletions point_transformer_v3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@ This repository contains a minimal implementation of Point Transformer V3 using

## Environment

Use the FVDB default development environment:
Use the FVDB default development environment and install the FVDB package:

```bash
cd fvdb/
conda env create -f env/dev_environment.yml
conda activate fvdb
./build.sh
```

Next, activate the environment and install additional dependencies specifically for the point transformer project.

```bash
cd fvdb/projects/point_transformer_v3
pip install -r requirements.txt
```

Expand All @@ -28,7 +32,7 @@ pip install -r requirements.txt

**Usage**:
```bash
python prepare_scannet_dataset.py --data_root /path/to/scannet --output_file scannet_samples.json --num_samples 10
python prepare_scannet_dataset.py --data_root /path/to/scannet --output_file scannet_samples.json --num_samples 16
```

**What it does**:
Expand Down Expand Up @@ -116,10 +120,10 @@ Run the PT-v3 model inference on the downloaded samples:

```bash
# Test with small dataset
python minimal_inference.py --data-path data/scannet_samples_small.json --voxel-size 0.1 --patch-size 1024
python minimal_inference.py --data-path data/scannet_samples_small.json --voxel-size 0.1 --patch-size 1024 --batch-size 1

# Test with large dataset
python minimal_inference.py --data-path data/scannet_samples_large.json --voxel-size 0.02 --patch-size 1024
python minimal_inference.py --data-path data/scannet_samples_large.json --voxel-size 0.02 --patch-size 1024 --batch-size 1
```

This will:
Expand Down
118 changes: 107 additions & 11 deletions point_transformer_v3/compute_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,39 @@
import numpy as np


def load_stats_file(filepath: str, logger: logging.Logger) -> List[Dict[str, Any]]:
def load_stats_file(filepath: str, logger: logging.Logger) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Load and parse a minimal_inference_stats.json file.

Args:
filepath: Path to the JSON file to load.
logger: Logger instance for error reporting here.

Returns:
List of dictionaries containing the parsed JSON data.
Tuple of (per_sample_stats, global_stats) containing the parsed JSON data.
If the file has old format (just a list), returns (data, empty_dict).

Raises:
SystemExit: If file is not found or contains invalid JSON.
"""
try:
with open(filepath, "r") as f:
return json.load(f)
data = json.load(f)

# Handle both old format (list) and new format (dict with global_stats and per_sample_stats)
if isinstance(data, list):
# Old format - just a list of per-sample stats
logger.info(f"Loading old format file: {filepath}")
return data, {}
elif isinstance(data, dict) and "per_sample_stats" in data:
# New format - structured with global and per-sample stats
logger.info(f"Loading new format file: {filepath}")
global_stats = data.get("global_stats", {})
per_sample_stats = data.get("per_sample_stats", [])
return per_sample_stats, global_stats
else:
logger.error(f"Unexpected JSON structure in file '{filepath}'")
sys.exit(1)

except FileNotFoundError:
logger.error(f"File '{filepath}' not found.")
sys.exit(1)
Expand Down Expand Up @@ -64,6 +81,7 @@ def compute_deviations(
"num_points": {"absolute": [], "relative": []},
"output_feats_sum": {"absolute": [], "relative": []},
"output_feats_last_element": {"absolute": [], "relative": []},
"loss": {"absolute": [], "relative": []},
}

for i, (entry1, entry2) in enumerate(zip(stats1, stats2)):
Expand All @@ -72,6 +90,7 @@ def compute_deviations(
if field in entry1 and field in entry2:
val1 = entry1[field]
val2 = entry2[field]

if isinstance(val1, (int, float)) and isinstance(val2, (int, float)):
# Absolute difference
abs_deviation = abs(val1 - val2)
Expand Down Expand Up @@ -100,6 +119,67 @@ def compute_deviations(
return avg_deviations


def compute_global_deviations(
    global_stats1: Dict[str, Any], global_stats2: Dict[str, Any], logger: logging.Logger
) -> Dict[str, Dict[str, float]]:
    """Compute deviations between global statistics from two files.

    Compares the ``first_module_grad_last16`` gradient vectors (via the L2 norm
    of their difference) and the scalar fields ``total_samples`` and
    ``batch_size``. Fields missing from either input are silently skipped.

    Args:
        global_stats1: Global stats dictionary from the first file.
        global_stats2: Global stats dictionary from the second file.
        logger: Logger instance for info/warning messages.

    Returns:
        Dictionary mapping each compared field to {"absolute": ..., "relative": ...}.
        The relative deviation is the absolute deviation divided by the larger
        magnitude of the two values (0.0 when both values are zero).
    """
    global_deviations = {}

    # Compare gradient vectors if present
    if "first_module_grad_last16" in global_stats1 and "first_module_grad_last16" in global_stats2:

        grad1 = global_stats1["first_module_grad_last16"]
        grad2 = global_stats2["first_module_grad_last16"]

        if isinstance(grad1, list) and isinstance(grad2, list) and len(grad1) == len(grad2):
            # L2 norm of the difference vector
            abs_deviation = float(np.linalg.norm(np.asarray(grad1) - np.asarray(grad2)))

            # Relative difference using L2 norms. Dividing by the larger norm
            # keeps the result well-defined when only one vector is zero
            # (previously a zero on either side forced the relative value to 0.0,
            # hiding real deviations).
            norm1 = float(np.linalg.norm(grad1))
            norm2 = float(np.linalg.norm(grad2))
            denom = max(norm1, norm2)
            rel_deviation = abs_deviation / denom if denom > 0 else 0.0

            global_deviations["first_module_grad_last16"] = {"absolute": abs_deviation, "relative": rel_deviation}

            logger.info(
                f"Global gradient deviation: absolute={abs_deviation:.6f}, relative={rel_deviation:.6f} ({rel_deviation*100:.2f}%)"
            )
        else:
            logger.warning("Gradient list format mismatch in global stats")

    # Compare other numerical global fields
    numerical_fields = ["total_samples", "batch_size"]
    for field in numerical_fields:
        if field in global_stats1 and field in global_stats2:
            val1 = global_stats1[field]
            val2 = global_stats2[field]

            if isinstance(val1, (int, float)) and isinstance(val2, (int, float)):
                abs_deviation = abs(val1 - val2)
                # Same fix as above: only the larger magnitude needs to be
                # non-zero for the relative deviation to be meaningful.
                denom = max(abs(val1), abs(val2))
                rel_deviation = abs_deviation / denom if denom > 0 else 0.0

                global_deviations[field] = {"absolute": abs_deviation, "relative": rel_deviation}

    return global_deviations


def main():
parser = argparse.ArgumentParser(
description="Compute average deviation between two minimal_inference_stats.json files"
Expand All @@ -122,17 +202,22 @@ def main():
logger = logging.getLogger(__name__)

# Load both files
stats1 = load_stats_file(args.stats_path_1, logger)
stats2 = load_stats_file(args.stats_path_2, logger)
stats1, global_stats1 = load_stats_file(args.stats_path_1, logger)
stats2, global_stats2 = load_stats_file(args.stats_path_2, logger)

logger.info(f"File 1 has {len(stats1)} entries")
logger.info(f"File 2 has {len(stats2)} entries")
logger.info(f"File 1 has {len(stats1)} per-sample entries")
logger.info(f"File 2 has {len(stats2)} per-sample entries")

# Compute deviations
# Compute per-sample deviations
avg_deviations = compute_deviations(stats1, stats2, logger)

# Compute global deviations if both files have global stats
global_deviations = {}
if global_stats1 and global_stats2:
global_deviations = compute_global_deviations(global_stats1, global_stats2, logger)

# Print results
logger.info("\nAverage Deviations:")
logger.info("\nPer-Sample Average Deviations:")
logger.info("=" * 50)
for field, diff_types in avg_deviations.items():
logger.info(f"{field}:")
Expand All @@ -142,12 +227,23 @@ def main():
else:
logger.info(f" {diff_type:10s}: {avg_dev:.6f}")

# Compute overall average deviations
if global_deviations:
logger.info("\nGlobal Statistics Deviations:")
logger.info("=" * 50)
for field, diff_types in global_deviations.items():
logger.info(f"{field}:")
for diff_type, dev in diff_types.items():
if diff_type == "relative":
logger.info(f" {diff_type:10s}: {dev:.6f} ({dev*100:.2f}%)")
else:
logger.info(f" {diff_type:10s}: {dev:.6f}")

# Compute overall average deviations for per-sample stats
overall_absolute = np.mean([diff_types["absolute"] for diff_types in avg_deviations.values()])
overall_relative = np.mean([diff_types["relative"] for diff_types in avg_deviations.values()])
logger.info("=" * 50)

logger.info("\nOverall Averages:")
logger.info("\nOverall Per-Sample Averages:")
logger.info(f"Absolute: {overall_absolute:.6f}")
logger.info(f"Relative: {overall_relative:.6f} ({overall_relative*100:.2f}%)")

Expand Down
Loading