From b98cf786d0e69d7d2ab686aed90bba61d0b111e9 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Thu, 18 Sep 2025 16:00:00 +0000 Subject: [PATCH 01/10] migrate the batch_ptv3 PR from previous repo. Signed-off-by: Hexu Zhao --- point_transformer_v3/.gitignore | 1 + point_transformer_v3/README.md | 12 +- point_transformer_v3/compute_difference.py | 118 ++++++++- point_transformer_v3/minimal_inference.py | 199 ++++++++------ point_transformer_v3/model.py | 242 +++++++++++------- .../prepare_scannet_dataset.py | 15 +- point_transformer_v3/requirements.txt | 3 +- 7 files changed, 386 insertions(+), 204 deletions(-) diff --git a/point_transformer_v3/.gitignore b/point_transformer_v3/.gitignore index 00f1999..0099b23 100644 --- a/point_transformer_v3/.gitignore +++ b/point_transformer_v3/.gitignore @@ -4,6 +4,7 @@ *.nsys-rep ../panoptic_segmentation/ptv3 /tests/fvdb-test-data +fvdb-test-data !requirements.txt /tests /data diff --git a/point_transformer_v3/README.md b/point_transformer_v3/README.md index 304b5e8..7a59693 100644 --- a/point_transformer_v3/README.md +++ b/point_transformer_v3/README.md @@ -4,15 +4,19 @@ This repository contains a minimal implementation of Point Transformer V3 using ## Environment -Use the FVDB default development environment: +Use the FVDB default development environment and install FVDB package: ```bash +cd fvdb/ conda env create -f env/dev_environment.yml +conda activate fvdb +./build.sh ``` Next, activate the environment and install additional dependancies specifically for the point transformer project. ```bash +cd fvdb/projects/point_transformer_v3 pip install -r requirements.txt ``` @@ -28,7 +32,7 @@ pip install -r requirements.txt **Usage**: ```bash -python prepare_scannet_dataset.py --data_root /path/to/scannet --output_file scannet_samples.json --num_samples 10 +python prepare_scannet_dataset.py --data_root /path/to/scannet --output_file scannet_samples.json --num_samples 16 ``` **What it does**: @@ -116,10 +120,10 @@ Run the PT-v3 model inference on the downloaded samples: ```bash # Test with small dataset -python minimal_inference.py --data-path data/scannet_samples_small.json --voxel-size 0.1 --patch-size 1024 +python minimal_inference.py --data-path data/scannet_samples_small.json --voxel-size 0.1 --patch-size 1024 --batch-size 1 # Test with large dataset -python minimal_inference.py --data-path data/scannet_samples_large.json --voxel-size 0.02 --patch-size 1024 +python minimal_inference.py --data-path data/scannet_samples_large.json --voxel-size 0.02 --patch-size 1024 --batch-size 1 ``` This will: diff --git a/point_transformer_v3/compute_difference.py b/point_transformer_v3/compute_difference.py index cd300dc..df75b1e 100644 --- a/point_transformer_v3/compute_difference.py +++ b/point_transformer_v3/compute_difference.py @@ -16,7 +16,7 @@ import numpy as np -def load_stats_file(filepath: str, logger: logging.Logger) -> List[Dict[str, Any]]: +def load_stats_file(filepath: str, logger: logging.Logger) -> tuple[List[Dict[str, Any]], Dict[str, Any]]: """Load and parse a minimal_inference_stats.json file. Args: @@ -24,14 +24,31 @@ def load_stats_file(filepath: str, logger: logging.Logger) -> List[Dict[str, Any logger: Logger instance for error reporting here. Returns: - List of dictionaries containing the parsed JSON data. + Tuple of (per_sample_stats, global_stats) containing the parsed JSON data. + If the file has old format (just a list), returns (data, empty_dict). Raises: SystemExit: If file is not found or contains invalid JSON. 
""" try: with open(filepath, "r") as f: - return json.load(f) + data = json.load(f) + + # Handle both old format (list) and new format (dict with global_stats and per_sample_stats) + if isinstance(data, list): + # Old format - just a list of per-sample stats + logger.info(f"Loading old format file: {filepath}") + return data, {} + elif isinstance(data, dict) and "per_sample_stats" in data: + # New format - structured with global and per-sample stats + logger.info(f"Loading new format file: {filepath}") + global_stats = data.get("global_stats", {}) + per_sample_stats = data.get("per_sample_stats", []) + return per_sample_stats, global_stats + else: + logger.error(f"Unexpected JSON structure in file '{filepath}'") + sys.exit(1) + except FileNotFoundError: logger.error(f"File '{filepath}' not found.") sys.exit(1) @@ -64,6 +81,7 @@ def compute_deviations( "num_points": {"absolute": [], "relative": []}, "output_feats_sum": {"absolute": [], "relative": []}, "output_feats_last_element": {"absolute": [], "relative": []}, + "loss": {"absolute": [], "relative": []}, } for i, (entry1, entry2) in enumerate(zip(stats1, stats2)): @@ -72,6 +90,7 @@ def compute_deviations( if field in entry1 and field in entry2: val1 = entry1[field] val2 = entry2[field] + if isinstance(val1, (int, float)) and isinstance(val2, (int, float)): # Absolute difference abs_deviation = abs(val1 - val2) @@ -100,6 +119,67 @@ def compute_deviations( return avg_deviations +def compute_global_deviations( + global_stats1: Dict[str, Any], global_stats2: Dict[str, Any], logger: logging.Logger +) -> Dict[str, Dict[str, float]]: + """Compute deviations between global statistics from two files. + + Args: + global_stats1: Global stats dictionary from the first file. + global_stats2: Global stats dictionary from the second file. + logger: Logger instance for warning messages. + + Returns: + Dictionary containing deviations for global fields. 
+ """ + global_deviations = {} + + # Compare gradient vectors if present + if "first_module_grad_last16" in global_stats1 and "first_module_grad_last16" in global_stats2: + + grad1 = global_stats1["first_module_grad_last16"] + grad2 = global_stats2["first_module_grad_last16"] + + if isinstance(grad1, list) and isinstance(grad2, list) and len(grad1) == len(grad2): + # Compute L2 norm of the difference vector + diff_vec = [v1 - v2 for v1, v2 in zip(grad1, grad2)] + abs_deviation = np.sqrt(sum(d * d for d in diff_vec)) + + # Relative difference using L2 norms + norm1 = np.sqrt(sum(v * v for v in grad1)) + norm2 = np.sqrt(sum(v * v for v in grad2)) + if norm1 > 0 and norm2 > 0: + rel_deviation = abs_deviation / max(norm1, norm2) + else: + rel_deviation = 0.0 + + global_deviations["first_module_grad_last16"] = {"absolute": abs_deviation, "relative": rel_deviation} + + logger.info( + f"Global gradient deviation: absolute={abs_deviation:.6f}, relative={rel_deviation:.6f} ({rel_deviation*100:.2f}%)" + ) + else: + logger.warning("Gradient list format mismatch in global stats") + + # Compare other numerical global fields + numerical_fields = ["total_samples", "batch_size"] + for field in numerical_fields: + if field in global_stats1 and field in global_stats2: + val1 = global_stats1[field] + val2 = global_stats2[field] + + if isinstance(val1, (int, float)) and isinstance(val2, (int, float)): + abs_deviation = abs(val1 - val2) + if abs(val1) > 0 and abs(val2) > 0: + rel_deviation = abs_deviation / max(abs(val1), abs(val2)) + else: + rel_deviation = 0.0 + + global_deviations[field] = {"absolute": abs_deviation, "relative": rel_deviation} + + return global_deviations + + def main(): parser = argparse.ArgumentParser( description="Compute average deviation between two minimal_inference_stats.json files" @@ -122,17 +202,22 @@ def main(): logger = logging.getLogger(__name__) # Load both files - stats1 = load_stats_file(args.stats_path_1, logger) - stats2 = load_stats_file(args.stats_path_2, logger) + stats1, global_stats1 = load_stats_file(args.stats_path_1, logger) + stats2, global_stats2 = load_stats_file(args.stats_path_2, logger) - logger.info(f"File 1 has {len(stats1)} entries") - logger.info(f"File 2 has {len(stats2)} entries") + logger.info(f"File 1 has {len(stats1)} per-sample entries") + logger.info(f"File 2 has {len(stats2)} per-sample entries") - # Compute deviations + # Compute per-sample deviations avg_deviations = compute_deviations(stats1, stats2, logger) + # Compute global deviations if both files have global stats + global_deviations = {} + if global_stats1 and global_stats2: + global_deviations = compute_global_deviations(global_stats1, global_stats2, logger) + # Print results - logger.info("\nAverage Deviations:") + logger.info("\nPer-Sample Average Deviations:") logger.info("=" * 50) for field, diff_types in avg_deviations.items(): logger.info(f"{field}:") @@ -142,12 +227,23 @@ def main(): else: logger.info(f" {diff_type:10s}: {avg_dev:.6f}") - # Compute overall average deviations + if global_deviations: + logger.info("\nGlobal Statistics Deviations:") + logger.info("=" * 50) + for field, diff_types in global_deviations.items(): + logger.info(f"{field}:") + for diff_type, dev in diff_types.items(): + if diff_type == "relative": + logger.info(f" {diff_type:10s}: {dev:.6f} ({dev*100:.2f}%)") + else: + logger.info(f" {diff_type:10s}: {dev:.6f}") + + # Compute overall average deviations for per-sample stats overall_absolute = np.mean([diff_types["absolute"] for diff_types in 
avg_deviations.values()]) overall_relative = np.mean([diff_types["relative"] for diff_types in avg_deviations.values()]) logger.info("=" * 50) - logger.info("\nOverall Averages:") + logger.info("\nOverall Per-Sample Averages:") logger.info(f"Absolute: {overall_absolute:.6f}") logger.info(f"Relative: {overall_relative:.6f} ({overall_relative*100:.2f}%)") diff --git a/point_transformer_v3/minimal_inference.py b/point_transformer_v3/minimal_inference.py index 7490836..8946e28 100644 --- a/point_transformer_v3/minimal_inference.py +++ b/point_transformer_v3/minimal_inference.py @@ -126,35 +126,31 @@ def create_ptv3_model(args, device, num_classes): return model -def prepare_input_from_scannet_points(color, grid_coords, voxel_size=0.1, device="cuda"): - """Prepare input from scannet points. +def prepare_batched_inputs_from_scannet_points(batch_samples, voxel_size=0.1, device="cuda"): + """Prepare batched inputs from a list of ScanNet-like samples. Args: - color: Color of the points. - grid_coords: Grid coordinates of the points. - voxel_size: Voxel size for grid sampling. - device: Device to place the tensors on. + batch_samples: list of dicts with keys "grid_coords" and "color". + voxel_size: float + device: torch.device or str Returns: - grid: GridBatch of the given point cloud. - jfeats: JaggedTensor of the point cloud features. + grid: fvdb.GridBatch + jfeats: fvdb.JaggedTensor with concatenated [ijk, color] """ - # Convert to torch tensors - grid_coords_tensor = torch.tensor(grid_coords, device=device, dtype=torch.int32) - color_tensor = torch.tensor(color, device=device, dtype=torch.float32) - - # Create jagged tensor for grid coordinates - coords_jagged = fvdb.JaggedTensor([grid_coords_tensor]) - - # Create grid from coordinates - grid = fvdb.GridBatch.from_ijk(coords_jagged, voxel_sizes=[[voxel_size, voxel_size, voxel_size]], origins=[0.0] * 3) - color_jdata = fvdb.JaggedTensor([color_tensor]) - color_vdb_order = grid.inject_from_ijk(coords_jagged, color_jdata) - - # Create features tensor (coordinates + color) + coords_list = [torch.tensor(s["grid_coords"], device=device, dtype=torch.int32) for s in batch_samples] + colors_list = [torch.tensor(s["color"], device=device, dtype=torch.float32) for s in batch_samples] + + coords_jagged = fvdb.JaggedTensor(coords_list) + grid = fvdb.GridBatch.from_ijk( + coords_jagged, + voxel_sizes=[[voxel_size, voxel_size, voxel_size]] * len(coords_list), + origins=[0.0] * 3, + ) + color_jagged = fvdb.JaggedTensor(colors_list) + color_vdb_order = grid.inject_from_ijk(coords_jagged, color_jagged) jfeats = color_vdb_order.jdata jfeats = fvdb.jcat([grid.ijk.float(), jfeats], dim=1) - return grid, jfeats @@ -164,8 +160,9 @@ def main(): parser.add_argument( "--data-path", type=str, default="scannet_samples.json", help="Path to the scannet samples json file" ) - parser.add_argument("--voxel-size", type=float, default=0.1, help="Voxel size for grid sampling") - parser.add_argument("--patch-size", type=int, default=0, help="Maximum points per sample") + parser.add_argument("--voxel-size", type=float, default=0.02, help="Voxel size for grid sampling") + parser.add_argument("--patch-size", type=int, default=1024, help="Maximum points per sample") + parser.add_argument("--batch-size", type=int, default=1, help="Number of samples per forward pass") parser.add_argument( "--model-mode", type=str, default="encdec_multihead_large_droppath", help="The model configuration to choose." 
) @@ -218,94 +215,128 @@ def main(): # Process each sample logger.info("Using fvdb-based ptv3 model.") statistics_to_save = [] - for sample_idx, sample in enumerate(scannet_data): - logger.info(f"--- Processing Sample {sample_idx + 1}/{len(scannet_data)} ---") + batch_size = int(args.batch_size) - # Extract data from sample - num_points = sample["num_points"] - grid_coords = np.array(sample["grid_coords"]) - color = np.array(sample["color"]) - label = np.array(sample["label"]) if "label" in sample else None - - logger.info(f"Sample {sample_idx + 1}: {num_points} points") + # Accumulate gradients across all batches + accumulated_grad_last16 = None + first_module_name_global = None + for batch_start in range(0, len(scannet_data), batch_size): + batch = scannet_data[batch_start : batch_start + batch_size] + logger.info(f"--- Processing Batch {batch_start//batch_size + 1} with {len(batch)} samples ---") # Run inference - logger.info("Running inference...") + logger.info("Running batched inference...") nvtx.range_push("inference") - nvtx.range_push("create_grid_from_points") - init_grid, init_feat = prepare_input_from_scannet_points(color, grid_coords, voxel_size=0.1, device=device) + nvtx.range_push("create_batched_grid_from_points") + init_grid, init_feat = prepare_batched_inputs_from_scannet_points( + batch, voxel_size=args.voxel_size, device=device + ) nvtx.range_pop() - grid, feats = model(init_grid, init_feat) # outputs is a dict with keys "grid" and "feats". It is not logits. + grid, feats = model(init_grid, init_feat) nvtx.range_pop() - # Test backward path + # Compute per-sample forward stats by splitting with offsets + offsets = feats.joffsets.to(device=feats.jdata.device, dtype=torch.int64) + num_samples_in_batch = offsets.numel() - 1 + + # Store per-sample forward statistics + for local_idx in range(num_samples_in_batch): + start = int(offsets[local_idx].item()) + end = int(offsets[local_idx + 1].item()) + j_slice = feats.jdata[start:end] + sample_dict = batch[local_idx] + + # Per-sample forward stats (independent of batch size) + statistics_to_save.append( + { + "num_points": int(sample_dict.get("num_points", end - start)), + "output_feats_shape": [int(end - start), int(j_slice.shape[1]) if j_slice.ndim == 2 else 0], + "output_feats_sum": float(j_slice.sum().item()) if j_slice.numel() > 0 else 0.0, + "output_feats_last_element": float(j_slice[-1, -1].item()) if (end - start) > 0 else 0.0, + "loss": float(j_slice.sum().item()) if j_slice.numel() > 0 else 0.0, # Per-sample loss + } + ) + + logger.info( + f"Sample {local_idx + 1}/{num_samples_in_batch}: feats.shape={j_slice.shape}, feats.sum()={j_slice.sum().item():.6f}, feats[last-element]={j_slice[-1, -1].item():.6f}" + ) + + # Test backward path - compute accumulated gradients for the entire batch logger.info("Testing backward path...") nvtx.range_push("backward") - - # Create a dummy loss (sum of output features) - loss = feats.jdata.sum() - - # Backward pass - loss.backward() + batch_loss = feats.jdata.sum() # Total loss for the entire batch + batch_loss.backward() nvtx.range_pop() - # Collect gradient statistics - grad_stats = {} - total_grad_norm = 0.0 - num_params_with_grad = 0 + # Collect gradient from the first module and accumulate across batches + batch_first_module_grad_last16 = None for name, param in model.named_parameters(): if param.grad is not None: - grad_norm = param.grad.norm().item() - grad_stats[f"{name}_grad_norm"] = grad_norm - grad_stats[f"{name}_grad_sum"] = param.grad.sum().item() - 
grad_stats[f"{name}_grad_last_element"] = param.grad.flatten()[-1].item() - total_grad_norm += grad_norm**2 - num_params_with_grad += 1 + # Get the last 16 gradients from the first module with gradients + if batch_first_module_grad_last16 is None: + if first_module_name_global is None: + first_module_name_global = name + + if name == first_module_name_global: + grad_flat = param.grad.flatten() + if len(grad_flat) >= 16: + batch_first_module_grad_last16 = grad_flat[-16:].tolist() + break + + # Accumulate gradients across all batches + if batch_first_module_grad_last16 is not None: + if accumulated_grad_last16 is None: + accumulated_grad_last16 = batch_first_module_grad_last16[:] else: - grad_stats[f"{name}_grad_norm"] = 0.0 - grad_stats[f"{name}_grad_sum"] = 0.0 - grad_stats[f"{name}_grad_last_element"] = 0.0 + accumulated_grad_last16 = [ + a + b for a, b in zip(accumulated_grad_last16, batch_first_module_grad_last16) + ] + + logger.info(f"Batch loss: {batch_loss.item():.6f}") + if batch_first_module_grad_last16: + logger.info( + f"Batch gradients from {first_module_name_global}: {[f'{x:.6f}' for x in batch_first_module_grad_last16[:4]]}...{[f'{x:.6f}' for x in batch_first_module_grad_last16[-4:]]}" + ) + if accumulated_grad_last16: + logger.info( + f"Accumulated gradients: {[f'{x:.6f}' for x in accumulated_grad_last16[:4]]}...{[f'{x:.6f}' for x in accumulated_grad_last16[-4:]]}" + ) - total_grad_norm = total_grad_norm**0.5 # L2 norm - - # Log the statistics of the output features and gradients - logger.info( - f"feats.shape: {feats.jdata.shape}. feats.sum(): {feats.jdata.sum().item()}. feats[last-element]: {feats.jdata[-1, -1].item()}" - ) - logger.info(f"Loss: {loss.item()}") - logger.info(f"Total gradient norm: {total_grad_norm}") - logger.info(f"Parameters with gradients: {num_params_with_grad}") - - statistics_to_save.append( - { - "num_points": num_points, - "output_feats_shape": feats.jdata.shape, - "output_feats_sum": feats.jdata.sum().item(), - "output_feats_last_element": feats.jdata[-1, -1].item(), - "loss": loss.item(), - "total_grad_norm": total_grad_norm, - "num_params_with_grad": num_params_with_grad, - "gradient_stats": grad_stats, - } - ) - - # Clear gradients for next iteration model.zero_grad() + # Create final output structure with global gradient info separate from per-sample stats + output_data = { + "global_stats": { + "first_module_grad_last16": accumulated_grad_last16, + "first_module_name": first_module_name_global, + "total_samples": len(statistics_to_save), + "batch_size": batch_size, + }, + "per_sample_stats": statistics_to_save, + } + # save the statistics to a json file output_file = args.data_path.replace(".json", f"_output.json") with open(output_file, "w") as f: - json.dump(statistics_to_save, f, indent=4) + json.dump(output_data, f, indent=4) logger.info(f"Statistics saved to {output_file}") + # Log final accumulated gradient summary + if accumulated_grad_last16: + grad_sum = sum(accumulated_grad_last16) + grad_norm = sum(x * x for x in accumulated_grad_last16) ** 0.5 + logger.info( + f"Final accumulated gradient from {first_module_name_global}: sum={grad_sum:.6f}, norm={grad_norm:.6f}" + ) + if __name__ == "__main__": main() ## Example commands: # scannet_samples_small.json -# python minimal_inference.py --data-path data/scannet_samples_small.json --voxel-size 0.1 --patch-size 1024 --model-mode encdec_multihead_large_droppath +# python minimal_inference.py --data-path data/scannet_samples_small.json --voxel-size 0.1 --patch-size 1024 --batch-size 1 # 
scannet_samples_large.json -# python minimal_inference.py --data-path data/scannet_samples_large.json --voxel-size 0.02 --patch-size 1024 --model-mode encdec_multihead_large_droppath +# python minimal_inference.py --data-path data/scannet_samples_large.json --voxel-size 0.02 --patch-size 1024 --batch-size 1 diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index 99ff573..edce826 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -200,92 +200,133 @@ def forward(self, grid, feats): qkv = self.qkv(feats_j) # (num_voxels, 3 * hidden_size) if self.sliding_window_attention and self.patch_size > 0: - # Perform sliding window attention using flash attention + # Perform sliding window attention per-grid using flash attention num_voxels = feats_j.shape[0] - qkv = qkv.view(1, num_voxels, 3, self.num_heads, self.head_dim) # (1, num_voxels, 3, num_heads, head_dim) - - window_size = (self.patch_size // 2, self.patch_size // 2) - - feats_out_j = flash_attn.flash_attn_qkvpacked_func( - qkv.half(), dropout_p=0.0, softmax_scale=1.0, window_size=window_size - ).reshape(num_voxels, self.hidden_size) + H = self.num_heads + D = self.head_dim + offsets = feats.joffsets.to(device=qkv.device, dtype=torch.int64) + outputs = [] + for b in range(offsets.numel() - 1): + start = int(offsets[b].item()) + end = int(offsets[b + 1].item()) + Li = end - start + if Li <= 0: + continue + qkv_b = qkv[start:end].view(1, Li, 3, H, D) + window_size = (self.patch_size // 2, self.patch_size // 2) + out_b = flash_attn.flash_attn_qkvpacked_func( + qkv_b.half(), dropout_p=0.0, softmax_scale=1.0, window_size=window_size + ).reshape(Li, self.hidden_size) + outputs.append(out_b) + if len(outputs) == 0: + feats_out_j = torch.empty_like(qkv[:, : self.hidden_size]) + else: + feats_out_j = torch.cat(outputs, dim=0) feats_out_j = feats_out_j.to(feats_j.dtype) elif self.patch_size > 0: - # Perform attention within each patch_size window. 
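The per-grid attention loops in this hunk rely on fvdb's jagged layout: voxels from every grid in the batch are concatenated along dim 0 of `jdata`, and `joffsets` stores the boundaries. A minimal sketch of that slicing pattern with plain tensors (the `jdata`/`joffsets` names mirror the patch; the toy sizes are invented):

```python
import torch

# Toy stand-in for a JaggedTensor: two "grids" with 5 and 3 voxels, 4 channels each.
jdata = torch.randn(8, 4)              # concatenated per-voxel features
joffsets = torch.tensor([0, 5, 8])     # per-grid boundaries, length = num_grids + 1

per_grid = []
for b in range(joffsets.numel() - 1):
    start, end = int(joffsets[b]), int(joffsets[b + 1])
    if end <= start:
        continue                       # skip empty grids, as the patch does
    per_grid.append(jdata[start:end])  # (L_b, C) slice holding only grid b's voxels

# Concatenating the slices recovers the original batched layout.
assert torch.equal(torch.cat(per_grid, dim=0), jdata)
```

Every per-grid branch in the attention code follows this start/end pattern, so attention never mixes voxels from different grids.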
+ # Perform attention within each patch_size window per-grid using varlen API num_voxels = feats_j.shape[0] - qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) # (num_voxels, 3, num_heads, head_dim) - cu_seqlens = torch.cat( - [ - torch.arange(0, num_voxels, self.patch_size, device=qkv.device, dtype=torch.int32), - torch.tensor([num_voxels], device=qkv.device, dtype=torch.int32), - ] - ) - - feats_out_j = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv.half(), cu_seqlens, max_seqlen=self.patch_size, dropout_p=0.0, softmax_scale=1.0 - ).reshape(num_voxels, self.hidden_size) - - if self.cross_patch_attention: - num_complete_patches = num_voxels // self.patch_size - remaining_voxels = num_voxels % self.patch_size - - if num_complete_patches > 1: # Only do cross-patch if we have multiple patches - complete_voxels = num_complete_patches * self.patch_size - qkv_complete = qkv[:complete_voxels] # (complete_voxels, 3, num_heads, head_dim) - - qkv_patches = qkv_complete.view( - num_complete_patches, self.patch_size, 3, self.num_heads, self.head_dim - ) + H = self.num_heads + D = self.head_dim + qkv = qkv.view(-1, 3, H, D) # (num_voxels, 3, num_heads, head_dim) + + # Build cu_seqlens as concatenation of per-grid patches so we never cross grid boundaries + offsets = feats.joffsets.to(device=qkv.device, dtype=torch.int64) + lengths = [] + for b in range(offsets.numel() - 1): + start = int(offsets[b].item()) + end = int(offsets[b + 1].item()) + Li = end - start + if Li <= 0: + continue + full = Li // self.patch_size + rem = Li % self.patch_size + if full > 0: + lengths.extend([self.patch_size] * full) + if rem > 0: + lengths.append(rem) + if len(lengths) == 0: + feats_out_j = torch.empty((0, self.hidden_size), device=qkv.device, dtype=feats_j.dtype) + else: + cu_seqlens = torch.zeros(len(lengths) + 1, device=qkv.device, dtype=torch.int32) + cu_seqlens[1:] = torch.as_tensor(lengths, device=qkv.device, dtype=torch.int32).cumsum(dim=0) + + feats_out_j = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv.half(), cu_seqlens, max_seqlen=self.patch_size, dropout_p=0.0, softmax_scale=1.0 + ).reshape(num_voxels, self.hidden_size) + + if self.cross_patch_attention: + # Apply cross-patch attention per-grid and add to feats_out_j + add_buf = torch.zeros_like(feats_out_j) + for b in range(offsets.numel() - 1): + start = int(offsets[b].item()) + end = int(offsets[b + 1].item()) + Li = end - start + if Li <= 0: + continue + num_complete_patches = Li // self.patch_size + remaining_voxels = Li % self.patch_size + + qkv_b = qkv[start:end] # (Li, 3, H, D) + if num_complete_patches > 0: + complete_voxels = num_complete_patches * self.patch_size + qkv_complete = qkv_b[:complete_voxels] + qkv_patches = qkv_complete.view(num_complete_patches, self.patch_size, 3, H, D) + if self.cross_patch_pooling == "mean": + qkv_pooled = qkv_patches.mean(dim=1) + elif self.cross_patch_pooling == "max": + qkv_pooled = qkv_patches.max(dim=1)[0] + else: + raise ValueError(f"Unsupported pooling method: {self.cross_patch_pooling}") + else: + qkv_pooled = None + complete_voxels = 0 - if self.cross_patch_pooling == "mean": - qkv_pooled = qkv_patches.mean(dim=1) # (num_complete_patches, 3, num_heads, head_dim) - elif self.cross_patch_pooling == "max": - qkv_pooled = qkv_patches.max(dim=1)[0] # (num_complete_patches, 3, num_heads, head_dim) - else: - raise ValueError(f"Unsupported pooling method: {self.cross_patch_pooling}") - - if remaining_voxels > 0: - qkv_remaining = qkv[complete_voxels:] # (remaining_voxels, 3, num_heads, head_dim) - if 
self.cross_patch_pooling == "mean": - qkv_remaining_pooled = qkv_remaining.mean( - dim=0, keepdim=True - ) # (1, 3, num_heads, head_dim) - else: # max pooling - qkv_remaining_pooled = qkv_remaining.max(dim=0, keepdim=True)[ - 0 - ] # (1, 3, num_heads, head_dim) - qkv_pooled = torch.cat( - [qkv_pooled, qkv_remaining_pooled], dim=0 - ) # (num_complete_patches + 1, 3, num_heads, head_dim) - num_total_patches = num_complete_patches + 1 - else: num_total_patches = num_complete_patches - - qkv_pooled_unsqueezed = qkv_pooled.unsqueeze(0) - cross_attn_out = flash_attn.flash_attn_qkvpacked_func( - qkv_pooled_unsqueezed.half(), dropout_p=0.0, softmax_scale=1.0 - ).reshape(num_total_patches, self.hidden_size) - - cross_attn_complete = cross_attn_out[:num_complete_patches] - cross_attn_expanded = cross_attn_complete.unsqueeze(1).expand( - -1, self.patch_size, -1 - ) # (num_complete_patches, patch_size, hidden_size) - cross_attn_flat = cross_attn_expanded.reshape( - complete_voxels, self.hidden_size - ) # (complete_voxels, hidden_size) - - cross_attn_all = torch.zeros_like(feats_out_j) - cross_attn_all[:complete_voxels] = cross_attn_flat.to(feats_out_j.dtype) - if remaining_voxels > 0: - cross_attn_all[complete_voxels:] = cross_attn_out[-1].unsqueeze(0).expand(remaining_voxels, -1) - - feats_out_j = feats_out_j + cross_attn_all - - feats_out_j = feats_out_j.to(feats_j.dtype) + if remaining_voxels > 0: + qkv_remaining = qkv_b[complete_voxels:] + if self.cross_patch_pooling == "mean": + qkv_remaining_pooled = qkv_remaining.mean(dim=0, keepdim=True) + else: + qkv_remaining_pooled = qkv_remaining.max(dim=0, keepdim=True)[0] + qkv_pooled = ( + qkv_remaining_pooled + if qkv_pooled is None + else torch.cat([qkv_pooled, qkv_remaining_pooled], dim=0) + ) + num_total_patches += 1 + + if num_total_patches > 0: + # qkv_pooled must be defined here because num_total_patches > 0 + assert qkv_pooled is not None + cross_attn_out = flash_attn.flash_attn_qkvpacked_func( + qkv_pooled.unsqueeze(0).half(), dropout_p=0.0, softmax_scale=1.0 + ).reshape(num_total_patches, self.hidden_size) + + cross_attn_all_b = torch.zeros( + (Li, self.hidden_size), device=qkv.device, dtype=cross_attn_out.dtype + ) + if num_complete_patches > 0: + cross_attn_complete = cross_attn_out[:num_complete_patches] + cross_attn_expanded = cross_attn_complete.unsqueeze(1).expand(-1, self.patch_size, -1) + cross_attn_flat = cross_attn_expanded.reshape(complete_voxels, self.hidden_size) + cross_attn_all_b[:complete_voxels] = cross_attn_flat + if remaining_voxels > 0: + cross_attn_all_b[complete_voxels:] = ( + cross_attn_out[-1].unsqueeze(0).expand(remaining_voxels, -1) + ) + + add_buf[start:end] = cross_attn_all_b.to(add_buf.dtype) + + # Apply cross-patch addition across the whole batch + feats_out_j = (feats_out_j + add_buf.to(feats_out_j.dtype)).to(feats_j.dtype) else: - assert False, "Only sliding window attention and patch attention are supported now. " + feats_out_j = qkv[:, : self.hidden_size].contiguous() + + # Ensure dtype matches original features before linear projection + feats_out_j = feats_out_j.to(feats_j.dtype) feats_out_j = self.proj(feats_out_j) feats_out_j = self.drop(feats_out_j) @@ -524,30 +565,34 @@ def __init__( sliding_window_attention and cross_patch_attention ), "sliding_window_attention and cross_patch_attention should not be used together." 
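The varlen patch-attention path above builds `cu_seqlens` so that fixed-size patches are carved out of each grid separately and never straddle a grid boundary. A standalone sketch of that construction, assuming only a `joffsets`-style boundary tensor and a patch size (the actual flash-attn call is omitted here since it needs fp16 CUDA tensors):

```python
import torch

def build_cu_seqlens(joffsets: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split each grid into patch_size chunks (plus a remainder) and return the
    cumulative-length tensor expected by the varlen attention API."""
    lengths = []
    for b in range(joffsets.numel() - 1):
        n = int(joffsets[b + 1] - joffsets[b])
        if n <= 0:
            continue
        lengths.extend([patch_size] * (n // patch_size))
        if n % patch_size:
            lengths.append(n % patch_size)
    cu = torch.zeros(len(lengths) + 1, dtype=torch.int32)
    cu[1:] = torch.as_tensor(lengths, dtype=torch.int32).cumsum(0)
    return cu

# Two grids with 5 and 3 voxels, patch size 2 -> patches of 2,2,1 and 2,1.
print(build_cu_seqlens(torch.tensor([0, 5, 8]), patch_size=2))
# tensor([0, 2, 4, 5, 7, 8], dtype=torch.int32)
```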
- self.embedding = PTV3_Embedding(input_dim, enc_channels[0]) + if len(enc_channels) > 0: + self.embedding = PTV3_Embedding(input_dim, enc_channels[0]) + else: + self.embedding = None self.num_stages = len(enc_depths) - self.enc = torch.nn.ModuleList() - enc_drop_path = [x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))] - for i in range(self.num_stages): - if i > 0: + if self.num_stages > 0: + self.enc = torch.nn.ModuleList() + enc_drop_path = [x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))] + for i in range(self.num_stages): + if i > 0: + self.enc.append( + PTV3_Pooling(kernel_size=2, in_channels=enc_channels[i - 1], out_channels=enc_channels[i]) + ) self.enc.append( - PTV3_Pooling(kernel_size=2, in_channels=enc_channels[i - 1], out_channels=enc_channels[i]) - ) - self.enc.append( - PTV3_Encoder( - enc_channels[i], - enc_depths[i], - enc_num_heads[i], - enc_drop_path[sum(enc_depths[:i]) : sum(enc_depths[: i + 1])], - proj_drop, - patch_size, - no_conv_in_cpe, - cross_patch_attention, - cross_patch_pooling, - sliding_window_attention, + PTV3_Encoder( + enc_channels[i], + enc_depths[i], + enc_num_heads[i], + enc_drop_path[sum(enc_depths[:i]) : sum(enc_depths[: i + 1])], + proj_drop, + patch_size, + no_conv_in_cpe, + cross_patch_attention, + cross_patch_pooling, + sliding_window_attention, + ) ) - ) # create decoder self.num_dec_stages = len(dec_depths) @@ -591,6 +636,9 @@ def __init__( def forward(self, grid, feats): nvtx.range_push("PTV3_Forward") + if self.embedding is None: + return grid, feats + grid, feats = self.embedding(grid, feats) layer_id = 0 diff --git a/point_transformer_v3/prepare_scannet_dataset.py b/point_transformer_v3/prepare_scannet_dataset.py index da928c7..0f35ed6 100644 --- a/point_transformer_v3/prepare_scannet_dataset.py +++ b/point_transformer_v3/prepare_scannet_dataset.py @@ -274,12 +274,12 @@ def main(): parser = argparse.ArgumentParser(description="Export ScanNet dataset samples") parser.add_argument("--data-root", required=True, help="ScanNet dataset root directory") parser.add_argument("--output", required=True, help="Output JSON file path") - parser.add_argument("--num-samples", type=int, default=10, help="Number of samples to export") + parser.add_argument("--num-samples", type=int, default=16, help="Number of samples to export") parser.add_argument("--split", default="train", choices=["train", "val", "test"], help="Dataset split to use") - parser.add_argument("--min-points", type=int, default=2048, help="Minimum points per sample") - parser.add_argument("--max-points", type=int, default=4096, help="Maximum points per sample") - parser.add_argument("--patch-size", type=int, default=0, help="Maximum points per sample") - parser.add_argument("--voxel-size", type=float, default=0.1, help="Voxel size for grid sampling") + parser.add_argument("--min-points", type=int, default=50000, help="Minimum points per sample") + parser.add_argument("--max-points", type=int, default=100000, help="Maximum points per sample") + parser.add_argument("--patch-size", type=int, default=1024, help="Maximum points per sample") + parser.add_argument("--voxel-size", type=float, default=0.02, help="Voxel size for grid sampling") args = parser.parse_args() @@ -305,7 +305,8 @@ def main(): main() # Create scannet_samples_small.json -# python prepare_scannet_dataset.py --data-root /home/hexuz/openvdb/fvdb/projects/sparse_attention/Pointcept/data/scannet --output tests/data/scannet_samples_small.json --num-samples 3 --split train --min-points 2048 --max-points 
4096 --voxel-size 0.1 --patch-size 1024 +# python prepare_scannet_dataset.py --data-root /home/hexuz/openvdb/fvdb/projects/sparse_attention/Pointcept/data/scannet --output data/scannet_samples_small.json --num-samples 8 --split train --min-points 2048 --max-points 4096 --voxel-size 0.1 --patch-size 1024 # Create scannet_samples_large.json -# python prepare_scannet_dataset.py --data-root /home/hexuz/openvdb/fvdb/projects/sparse_attention/Pointcept/data/scannet --output tests/data/scannet_samples_large.json --num-samples 3 --split train --min-points 50000 --max-points 100000 --voxel-size 0.02 --patch-size 1024 +# python prepare_scannet_dataset.py --data-root /home/hexuz/openvdb/fvdb/projects/sparse_attention/Pointcept/data/scannet --output data/scannet_samples_large.json --num-samples 4 --split train --min-points 50000 --max-points 100000 --voxel-size 0.02 --patch-size 1024 + diff --git a/point_transformer_v3/requirements.txt b/point_transformer_v3/requirements.txt index 1224ee2..91b2155 100644 --- a/point_transformer_v3/requirements.txt +++ b/point_transformer_v3/requirements.txt @@ -1 +1,2 @@ -flash_attn==2.7.4.post1 +flash_attn==2.7.2.post1 +timm From cc01a23eaee1d1793fe8d37155508d458b7509a2 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Thu, 18 Sep 2025 16:03:54 +0000 Subject: [PATCH 02/10] format Signed-off-by: Hexu Zhao --- point_transformer_v3/prepare_scannet_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/point_transformer_v3/prepare_scannet_dataset.py b/point_transformer_v3/prepare_scannet_dataset.py index 0f35ed6..09d0ebd 100644 --- a/point_transformer_v3/prepare_scannet_dataset.py +++ b/point_transformer_v3/prepare_scannet_dataset.py @@ -309,4 +309,3 @@ def main(): # Create scannet_samples_large.json # python prepare_scannet_dataset.py --data-root /home/hexuz/openvdb/fvdb/projects/sparse_attention/Pointcept/data/scannet --output data/scannet_samples_large.json --num-samples 4 --split train --min-points 50000 --max-points 100000 --voxel-size 0.02 --patch-size 1024 - From 9d6fcc3f2d9946186d03178e1516dff932d00637 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Mon, 20 Oct 2025 18:38:59 +0000 Subject: [PATCH 03/10] (1) support multiple order (2) support batch mode (3) Signed-off-by: Hexu Zhao --- point_transformer_v3/model.py | 138 +++++++++++++++++++++++++++++----- 1 file changed, 121 insertions(+), 17 deletions(-) diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index edce826..4d5896d 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -1,7 +1,7 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Tuple +from typing import Dict, Tuple, Union, List # Add NVTX import for profiling import flash_attn @@ -9,6 +9,7 @@ import torch.nn import torch.nn.functional as F from timm.layers import DropPath +from functools import partial import fvdb @@ -34,22 +35,60 @@ class PTV3_Embedding(torch.nn.Module): PTV3_Embedding for 3D point cloud embedding. """ - def __init__(self, in_channels, embed_channels): + def __init__(self, in_channels, embed_channels, norm_layer_module: torch.nn.Module = torch.nn.LayerNorm, embedding_mode: str = "linear"): """ Args: in_channels (int): Number of channels in the input features. embed_channels (int): Number of channels in the output features. + norm_layer_module (torch.nn.Module): Normalization layer module. + embedding_mode (str): Mode for the embedding layer, "linear" or "conv3x3", "conv5x5". 
""" super().__init__() - self.linear = torch.nn.Linear(in_channels, embed_channels) - self.norm = torch.nn.LayerNorm(embed_channels) + self.embedding_mode = embedding_mode + if embedding_mode == "linear": + self.embed = torch.nn.Linear(in_channels, embed_channels) + elif embedding_mode == "conv3x3": + # use fvdb's convolution + self.embed_conv3x3_1 = fvdb.nn.SparseConv3d(in_channels, embed_channels, kernel_size=3, stride=1, bias=False) + elif embedding_mode == "conv5x5": + # version 1: 3x3 + 3x3 + # Total Params: 27 × in_channels × embed_channels + 27 × embed_channels^2 + self.embed_conv3x3_1 = fvdb.nn.SparseConv3d(in_channels, embed_channels, kernel_size=3, stride=1, bias=False) + self.embed_conv3x3_2 = fvdb.nn.SparseConv3d(embed_channels, embed_channels, kernel_size=3, stride=1, bias=False) + # version 2: 5x5 (unsupported yet) + # Total Params: 125 × in_channels × embed_channels + # self.embed_conv5x5_1 = fvdb.nn.SparseConv3d(in_channels, embed_channels, kernel_size=5, stride=1) + else: + raise ValueError(f"Unsupported embedding mode: {embedding_mode}") + self.norm_layer = norm_layer_module(embed_channels) self.act_layer = torch.nn.GELU() def forward(self, grid, feats): nvtx.range_push("PTV3_Embedding") - jfeats = feats.jdata - jfeats = self.linear(jfeats) - jfeats = self.norm(jfeats) + + # Initialize kmap if not present + if not hasattr(grid, "kmap"): + grid.kmap = None + + if self.embedding_mode == "linear": + jfeats = feats.jdata + jfeats = self.embed(jfeats) + elif self.embedding_mode == "conv3x3": + # Use fvdb convolution with kmap handling + grid, feats, out_kmap = self.embed_conv3x3_1._dispatch_conv(feats, grid, grid.kmap, grid) + grid.kmap = out_kmap # update the kmap + jfeats = feats.jdata + # There is no bias in the convolution-based embedding layer. + elif self.embedding_mode == "conv5x5": + # First 3x3 convolution + grid, feats, out_kmap = self.embed_conv3x3_1._dispatch_conv(feats, grid, grid.kmap, grid) + grid.kmap = out_kmap # update the kmap + # Second 3x3 convolution + grid, feats, out_kmap = self.embed_conv3x3_2._dispatch_conv(feats, grid, grid.kmap, grid) + grid.kmap = out_kmap # update the kmap + jfeats = feats.jdata + + jfeats = self.norm_layer(jfeats) jfeats = self.act_layer(jfeats) feats = feats.jagged_like(jfeats) @@ -58,7 +97,7 @@ def forward(self, grid, feats): class PTV3_Pooling(torch.nn.Module): - def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: int = 64): + def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: int = 64, norm_layer_module: torch.nn.Module = torch.nn.LayerNorm): """ Args: kernel_size (int): Kernel size for the pooling operation. 
@@ -68,7 +107,7 @@ def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: in super().__init__() self.kernel_size = kernel_size self.proj = torch.nn.Linear(in_channels, out_channels) - self.ln_layer = torch.nn.LayerNorm(out_channels) + self.norm_layer = norm_layer_module(out_channels) self.act_layer = torch.nn.GELU() def forward(self, grid, feats): @@ -78,7 +117,7 @@ def forward(self, grid, feats): ds_feature, ds_grid = grid.max_pool(self.kernel_size, feats, stride=self.kernel_size, coarse_grid=None) ds_feature_j = ds_feature.jdata - ds_feature_j = self.ln_layer(ds_feature_j) + ds_feature_j = self.norm_layer(ds_feature_j) ds_feature_j = self.act_layer(ds_feature_j) ds_feature = ds_feature.jagged_like(ds_feature_j) nvtx.range_pop() @@ -87,7 +126,7 @@ def forward(self, grid, feats): class PTV3_Unpooling(torch.nn.Module): - def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: int = 64, skip_channels: int = 64): + def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: int = 64, skip_channels: int = 64, norm_layer_module: torch.nn.Module = torch.nn.LayerNorm): """ Args: kernel_size (int): Kernel size for the pooling operation. @@ -102,10 +141,10 @@ def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: in self.out_channels = out_channels self.proj = torch.nn.Linear(in_channels, out_channels) - self.norm = torch.nn.LayerNorm(out_channels) + self.norm = norm_layer_module(out_channels) self.act_layer = torch.nn.GELU() self.proj_skip = torch.nn.Linear(skip_channels, out_channels) - self.norm_skip = torch.nn.LayerNorm(out_channels) + self.norm_skip = norm_layer_module(out_channels) self.act_layer_skip = torch.nn.GELU() def forward(self, grid, feats, last_grid, last_feats): @@ -165,6 +204,7 @@ def __init__( cross_patch_attention: bool = False, cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, + order_type: str = "vdb", ): """ Args: @@ -175,6 +215,7 @@ def __init__( cross_patch_attention (bool): Whether to use cross-patch attention. cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). + order_type (str): The type of order of the points, "vdb" or "z". 
""" super().__init__() self.hidden_size = hidden_size @@ -186,17 +227,37 @@ def __init__( self.proj = torch.nn.Linear(hidden_size, hidden_size) self.drop = torch.nn.Dropout(proj_drop) self.patch_size = patch_size - + self.order_type = order_type self.cross_patch_attention = cross_patch_attention self.cross_patch_pooling = cross_patch_pooling # "mean" or "max" # Sliding window attention parameter self.sliding_window_attention = sliding_window_attention + def _permute(self, grid, order_type): + if order_type == "z": + return grid.permutation_morton() + elif order_type == "z-trans": + return grid.permutation_morton_zyx() + elif order_type == "hilbert": + return grid.permutation_hilbert() + elif order_type == "hilbert-trans": + return grid.permutation_hilbert_zyx() + else: + raise ValueError(f"Unsupported order type: {order_type}") + def forward(self, grid, feats): nvtx.range_push("PTV3_Attention") feats_j = feats.jdata + # import pdb; pdb.set_trace() + + if self.order_type != "vdb": + perm = self._permute(grid, self.order_type).jdata.squeeze(-1) # [num_voxels] + # Use torch.gather for permutation: expand perm to match feats_j dimensions + perm_expanded = perm.unsqueeze(-1).expand(-1, feats_j.shape[-1]) # [num_voxels, hidden_size] + feats_j = torch.gather(feats_j, 0, perm_expanded) + qkv = self.qkv(feats_j) # (num_voxels, 3 * hidden_size) if self.sliding_window_attention and self.patch_size > 0: @@ -328,6 +389,12 @@ def forward(self, grid, feats): # Ensure dtype matches original features before linear projection feats_out_j = feats_out_j.to(feats_j.dtype) + if self.order_type != "vdb": + perm_reverse = torch.empty_like(perm) + perm_reverse[perm] = torch.arange(perm.shape[0], device=perm.device) # [num_voxels] + perm_reverse_expanded = perm_reverse.unsqueeze(-1).expand(-1, feats_out_j.shape[-1]) # [num_voxels, hidden_size] + feats_out_j = torch.gather(feats_out_j, 0, perm_reverse_expanded) + feats_out_j = self.proj(feats_out_j) feats_out_j = self.drop(feats_out_j) feats_out = grid.jagged_like(feats_out_j) @@ -365,6 +432,7 @@ def forward(self, grid, feats): grid.kmap = None if not self.no_conv_in_cpe: + # import pdb; pdb.set_trace() grid, out_feature, out_kmap = self.cpe[0]._dispatch_conv(feats, grid, grid.kmap, grid) grid.kmap = out_kmap # update the kmap if self.cpe[0].bias is not None: @@ -393,6 +461,7 @@ def __init__( cross_patch_attention: bool = False, cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, + order_type: str = "vdb", ): """ Args: @@ -405,6 +474,7 @@ def __init__( cross_patch_attention (bool): Whether to use cross-patch attention. cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). + order_type (str): The type of order of the points: "vdb", "z", "z-trans", "hilbert", "hilbert-trans". 
""" super().__init__() # one attention and one mlp @@ -418,8 +488,10 @@ def __init__( cross_patch_attention, cross_patch_pooling, sliding_window_attention, + order_type, ) self.norm2 = torch.nn.Sequential(torch.nn.LayerNorm(hidden_size)) # norm2.0 + self.order_type = order_type self.mlp = PTV3_MLP(hidden_size, proj_drop) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity() @@ -465,6 +537,7 @@ def __init__( cross_patch_attention: bool = False, cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, + order_type: str = "vdb", ): """ Args: @@ -478,6 +551,7 @@ def __init__( cross_patch_attention (bool): Whether to use cross-patch attention. cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). + order_type (str): The type of order of the points, "vdb" or "z". """ super().__init__() self.depth = depth @@ -493,11 +567,12 @@ def __init__( cross_patch_attention, cross_patch_pooling, sliding_window_attention, + order_type, ) for i in range(depth) ] ) - + self.order_type = order_type def forward(self, grid, feats): for block in self.blocks: grid, feats = block(grid, feats) @@ -525,10 +600,13 @@ def __init__( patch_size: int = 0, drop_path: float = 0.3, proj_drop: float = 0.0, + enable_batch_norm: bool = False, + embedding_mode: str = "linear", no_conv_in_cpe: bool = False, cross_patch_attention: bool = False, cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, + order_type: Union[str, List[str]] = "vdb", ) -> None: """ ptv3 for 3D point cloud segmentation. @@ -546,10 +624,15 @@ def __init__( patch_size (int): Patch size for patch attention. drop_path (float): Drop path rate for regularization. proj_drop (float): Dropout rate for MLP layers. + enable_batch_norm (bool): Whether to use batch normalization for the embedding, down pooling, and up pooling. + embedding_mode (bool): the mode for the embedding layer, "linear" or "conv3x3", "conv5x5". no_conv_in_cpe (bool): Whether to disable convolution in CPE. cross_patch_attention (bool): Whether to use cross-patch attention. cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). + order_type (Union[str, List[str]]): The type of order of the points. Can be a single string ("vdb", "z", "z-trans", "hilbert", "hilbert-trans") + for all layers, or a list of strings for different layers. Each encoder and decoder stage will use + order_type[i % len(order_type)] where i is the stage index. """ super().__init__() self.num_classes = num_classes @@ -559,6 +642,18 @@ def __init__( self.cross_patch_attention = cross_patch_attention self.cross_patch_pooling = cross_patch_pooling self.sliding_window_attention = sliding_window_attention + + # Handle order_type: convert to list for uniform processing + if isinstance(order_type, str): + self.order_type_list = [order_type] + else: + self.order_type_list = order_type + self.order_type = order_type # Keep original for backward compatibility + + if not enable_batch_norm: + self.norm_layer = torch.nn.LayerNorm + else: + self.norm_layer = partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01) # sliding_window_attention and cross_patch_attention should not be used together. 
assert not ( @@ -566,7 +661,7 @@ def __init__( ), "sliding_window_attention and cross_patch_attention should not be used together." if len(enc_channels) > 0: - self.embedding = PTV3_Embedding(input_dim, enc_channels[0]) + self.embedding = PTV3_Embedding(input_dim, enc_channels[0], norm_layer_module=self.norm_layer, embedding_mode=embedding_mode) else: self.embedding = None @@ -577,8 +672,10 @@ def __init__( for i in range(self.num_stages): if i > 0: self.enc.append( - PTV3_Pooling(kernel_size=2, in_channels=enc_channels[i - 1], out_channels=enc_channels[i]) + PTV3_Pooling(kernel_size=2, in_channels=enc_channels[i - 1], out_channels=enc_channels[i], norm_layer_module=self.norm_layer) ) + # Select order_type for this encoder stage using modulo + stage_order_type = self.order_type_list[i % len(self.order_type_list)] self.enc.append( PTV3_Encoder( enc_channels[i], @@ -591,6 +688,7 @@ def __init__( cross_patch_attention, cross_patch_pooling, sliding_window_attention, + stage_order_type, ) ) @@ -616,8 +714,13 @@ def __init__( in_channels=last_channels, out_channels=dec_channels[i], skip_channels=enc_channels[self.num_stages - 2 - i], + norm_layer_module=self.norm_layer, ) ) + # Select order_type for this decoder stage using modulo + # Use reverse order for decoder (from last encoder stage backwards) + dec_stage_idx = self.num_stages - 1 - i + stage_order_type = self.order_type_list[dec_stage_idx % len(self.order_type_list)] self.dec.append( PTV3_Encoder( dec_channels[i], @@ -630,6 +733,7 @@ def __init__( cross_patch_attention, cross_patch_pooling, sliding_window_attention, + stage_order_type, ) ) From bd9fc907f154410aaef4bc6e03b28bb1c66112ba Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Mon, 20 Oct 2025 14:12:39 -0700 Subject: [PATCH 04/10] (1) clean the code (2) format. Signed-off-by: Hexu Zhao --- point_transformer_v3/model.py | 117 +++++++++++++++++++++++----------- 1 file changed, 79 insertions(+), 38 deletions(-) diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index 4d5896d..3f21f02 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -35,57 +35,73 @@ class PTV3_Embedding(torch.nn.Module): PTV3_Embedding for 3D point cloud embedding. """ - def __init__(self, in_channels, embed_channels, norm_layer_module: torch.nn.Module = torch.nn.LayerNorm, embedding_mode: str = "linear"): + def __init__( + self, + in_channels, + embed_channels, + norm_layer_module: torch.nn.Module = torch.nn.LayerNorm, + embedding_mode: str = "linear", + ): """ Args: in_channels (int): Number of channels in the input features. embed_channels (int): Number of channels in the output features. norm_layer_module (torch.nn.Module): Normalization layer module. - embedding_mode (str): Mode for the embedding layer, "linear" or "conv3x3", "conv5x5". + embedding_mode (str): The type of embedding layer, "linear" or "conv3x3", "conv5x5". 
""" super().__init__() self.embedding_mode = embedding_mode + if embedding_mode == "linear": self.embed = torch.nn.Linear(in_channels, embed_channels) elif embedding_mode == "conv3x3": - # use fvdb's convolution - self.embed_conv3x3_1 = fvdb.nn.SparseConv3d(in_channels, embed_channels, kernel_size=3, stride=1, bias=False) + # Initialize embedding using FVDB's sparse 3D convolution + self.embed_conv3x3_1 = fvdb.nn.SparseConv3d( + in_channels, embed_channels, kernel_size=3, stride=1, bias=False + ) elif embedding_mode == "conv5x5": - # version 1: 3x3 + 3x3 - # Total Params: 27 × in_channels × embed_channels + 27 × embed_channels^2 - self.embed_conv3x3_1 = fvdb.nn.SparseConv3d(in_channels, embed_channels, kernel_size=3, stride=1, bias=False) - self.embed_conv3x3_2 = fvdb.nn.SparseConv3d(embed_channels, embed_channels, kernel_size=3, stride=1, bias=False) - # version 2: 5x5 (unsupported yet) - # Total Params: 125 × in_channels × embed_channels + ## Implementation Option 1: Cascaded 3x3 convolutions + # This approach uses two 3x3 convs to achieve a 5x5 receptive field with fewer parameters + # Parameters: (27 × in_channels × embed_channels) + (27 × embed_channels²) + self.embed_conv3x3_1 = fvdb.nn.SparseConv3d( + in_channels, embed_channels, kernel_size=3, stride=1, bias=False + ) + self.embed_conv3x3_2 = fvdb.nn.SparseConv3d( + embed_channels, embed_channels, kernel_size=3, stride=1, bias=False + ) + + ## Implementation Option 2: Direct 5x5 convolution + # TODO: Implementation pending - requires additional sparse convolution support from fVDB-core. + # Expected parameters: 125 × in_channels × embed_channels # self.embed_conv5x5_1 = fvdb.nn.SparseConv3d(in_channels, embed_channels, kernel_size=5, stride=1) else: raise ValueError(f"Unsupported embedding mode: {embedding_mode}") + self.norm_layer = norm_layer_module(embed_channels) self.act_layer = torch.nn.GELU() def forward(self, grid, feats): nvtx.range_push("PTV3_Embedding") - - # Initialize kmap if not present + + # Initialize kernel map (kmap) for sparse convolution operations + # kmap tracks the mapping between input and output features during sparse convolutions if not hasattr(grid, "kmap"): grid.kmap = None - + if self.embedding_mode == "linear": jfeats = feats.jdata jfeats = self.embed(jfeats) elif self.embedding_mode == "conv3x3": - # Use fvdb convolution with kmap handling + # Apply 3x3 sparse convolution while maintaining kernel mapping + # Note: Bias is intentionally disabled to maintain consistency with standard transformer architectures grid, feats, out_kmap = self.embed_conv3x3_1._dispatch_conv(feats, grid, grid.kmap, grid) - grid.kmap = out_kmap # update the kmap + grid.kmap = out_kmap jfeats = feats.jdata - # There is no bias in the convolution-based embedding layer. 
elif self.embedding_mode == "conv5x5": - # First 3x3 convolution grid, feats, out_kmap = self.embed_conv3x3_1._dispatch_conv(feats, grid, grid.kmap, grid) - grid.kmap = out_kmap # update the kmap - # Second 3x3 convolution + grid.kmap = out_kmap grid, feats, out_kmap = self.embed_conv3x3_2._dispatch_conv(feats, grid, grid.kmap, grid) - grid.kmap = out_kmap # update the kmap + grid.kmap = out_kmap jfeats = feats.jdata jfeats = self.norm_layer(jfeats) @@ -97,7 +113,13 @@ def forward(self, grid, feats): class PTV3_Pooling(torch.nn.Module): - def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: int = 64, norm_layer_module: torch.nn.Module = torch.nn.LayerNorm): + def __init__( + self, + kernel_size: int = 2, + in_channels: int = 64, + out_channels: int = 64, + norm_layer_module: torch.nn.Module = torch.nn.LayerNorm, + ): """ Args: kernel_size (int): Kernel size for the pooling operation. @@ -126,7 +148,14 @@ def forward(self, grid, feats): class PTV3_Unpooling(torch.nn.Module): - def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: int = 64, skip_channels: int = 64, norm_layer_module: torch.nn.Module = torch.nn.LayerNorm): + def __init__( + self, + kernel_size: int = 2, + in_channels: int = 64, + out_channels: int = 64, + skip_channels: int = 64, + norm_layer_module: torch.nn.Module = torch.nn.LayerNorm, + ): """ Args: kernel_size (int): Kernel size for the pooling operation. @@ -149,8 +178,9 @@ def __init__(self, kernel_size: int = 2, in_channels: int = 64, out_channels: in def forward(self, grid, feats, last_grid, last_feats): - feats_j = self.proj(feats.jdata) - # BUG: When enabled AMP within Pointcept training pipeline, despite both the input and weights are float32, the output becomes float16. + feats_j = self.proj( + feats.jdata + ) # BUG: When enabled AMP, despite both feats.jdata and linear.weights are float32, the output becomes float16 which causes the subsequent convolution operation to fail. feats_j = self.norm(feats_j) feats_j = self.act_layer(feats_j) @@ -162,7 +192,9 @@ def forward(self, grid, feats, last_grid, last_feats): feats_j = feats.jdata new_feats_j = last_feats_j + feats_j - last_grid.kmap = None # the topology of the last grid is not valid anymore. + last_grid.kmap = ( + None # Because of the pooling operation, the previous kmap for convolution is not valid anymore. + ) return last_grid, last_grid.jagged_like(new_feats_j) @@ -250,16 +282,16 @@ def forward(self, grid, feats): nvtx.range_push("PTV3_Attention") feats_j = feats.jdata - # import pdb; pdb.set_trace() - if self.order_type != "vdb": - perm = self._permute(grid, self.order_type).jdata.squeeze(-1) # [num_voxels] + perm = self._permute(grid, self.order_type).jdata.squeeze(-1) # [num_voxels] # Use torch.gather for permutation: expand perm to match feats_j dimensions - perm_expanded = perm.unsqueeze(-1).expand(-1, feats_j.shape[-1]) # [num_voxels, hidden_size] + perm_expanded = perm.unsqueeze(-1).expand(-1, feats_j.shape[-1]) # [num_voxels, hidden_size] feats_j = torch.gather(feats_j, 0, perm_expanded) qkv = self.qkv(feats_j) # (num_voxels, 3 * hidden_size) + # TODO: only keep the sliding window attention and the default window attention. 
+ if self.sliding_window_attention and self.patch_size > 0: # Perform sliding window attention per-grid using flash attention num_voxels = feats_j.shape[0] @@ -391,8 +423,10 @@ def forward(self, grid, feats): if self.order_type != "vdb": perm_reverse = torch.empty_like(perm) - perm_reverse[perm] = torch.arange(perm.shape[0], device=perm.device) # [num_voxels] - perm_reverse_expanded = perm_reverse.unsqueeze(-1).expand(-1, feats_out_j.shape[-1]) # [num_voxels, hidden_size] + perm_reverse[perm] = torch.arange(perm.shape[0], device=perm.device) # [num_voxels] + perm_reverse_expanded = perm_reverse.unsqueeze(-1).expand( + -1, feats_out_j.shape[-1] + ) # [num_voxels, hidden_size] feats_out_j = torch.gather(feats_out_j, 0, perm_reverse_expanded) feats_out_j = self.proj(feats_out_j) @@ -432,7 +466,6 @@ def forward(self, grid, feats): grid.kmap = None if not self.no_conv_in_cpe: - # import pdb; pdb.set_trace() grid, out_feature, out_kmap = self.cpe[0]._dispatch_conv(feats, grid, grid.kmap, grid) grid.kmap = out_kmap # update the kmap if self.cpe[0].bias is not None: @@ -477,7 +510,7 @@ def __init__( order_type (str): The type of order of the points: "vdb", "z", "z-trans", "hilbert", "hilbert-trans". """ super().__init__() - # one attention and one mlp + self.cpe = PTV3_CPE(hidden_size, no_conv_in_cpe) self.norm1 = torch.nn.Sequential(torch.nn.LayerNorm(hidden_size)) # norm1.0 self.attn = PTV3_Attention( @@ -573,6 +606,7 @@ def __init__( ] ) self.order_type = order_type + def forward(self, grid, feats): for block in self.blocks: grid, feats = block(grid, feats) @@ -630,8 +664,8 @@ def __init__( cross_patch_attention (bool): Whether to use cross-patch attention. cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). - order_type (Union[str, List[str]]): The type of order of the points. Can be a single string ("vdb", "z", "z-trans", "hilbert", "hilbert-trans") - for all layers, or a list of strings for different layers. Each encoder and decoder stage will use + order_type (Union[str, List[str]]): The type of order of the points. Can be a single string ("vdb", "z", "z-trans", "hilbert", "hilbert-trans") + for all layers, or a list of strings for different layers. Each encoder and decoder stage will use order_type[i % len(order_type)] where i is the stage index. """ super().__init__() @@ -642,7 +676,7 @@ def __init__( self.cross_patch_attention = cross_patch_attention self.cross_patch_pooling = cross_patch_pooling self.sliding_window_attention = sliding_window_attention - + # Handle order_type: convert to list for uniform processing if isinstance(order_type, str): self.order_type_list = [order_type] @@ -661,7 +695,9 @@ def __init__( ), "sliding_window_attention and cross_patch_attention should not be used together." 
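The `order_type` reordering above is a plain gather by a permutation before attention, followed by a gather by the inverse permutation afterwards. A self-contained sketch of that round trip (here `perm` is random; in the patch it comes from the grid's `permutation_*` methods):

```python
import torch

feats = torch.randn(6, 4)              # (num_voxels, hidden_size)
perm = torch.randperm(6)               # stand-in for a Morton/Hilbert ordering

# Reorder features into the serialized order (same gather pattern as the patch).
perm_exp = perm.unsqueeze(-1).expand(-1, feats.shape[-1])
feats_sorted = torch.gather(feats, 0, perm_exp)

# Invert the permutation and gather back to the original voxel order.
perm_reverse = torch.empty_like(perm)
perm_reverse[perm] = torch.arange(perm.numel())
feats_back = torch.gather(feats_sorted, 0, perm_reverse.unsqueeze(-1).expand(-1, feats.shape[-1]))

assert torch.equal(feats_back, feats)
```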
if len(enc_channels) > 0: - self.embedding = PTV3_Embedding(input_dim, enc_channels[0], norm_layer_module=self.norm_layer, embedding_mode=embedding_mode) + self.embedding = PTV3_Embedding( + input_dim, enc_channels[0], norm_layer_module=self.norm_layer, embedding_mode=embedding_mode + ) else: self.embedding = None @@ -672,7 +708,12 @@ def __init__( for i in range(self.num_stages): if i > 0: self.enc.append( - PTV3_Pooling(kernel_size=2, in_channels=enc_channels[i - 1], out_channels=enc_channels[i], norm_layer_module=self.norm_layer) + PTV3_Pooling( + kernel_size=2, + in_channels=enc_channels[i - 1], + out_channels=enc_channels[i], + norm_layer_module=self.norm_layer, + ) ) # Select order_type for this encoder stage using modulo stage_order_type = self.order_type_list[i % len(self.order_type_list)] From 2a0e5a485254da7bf01f1046baa0a0e6ce9083bd Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Tue, 21 Oct 2025 18:17:25 -0700 Subject: [PATCH 05/10] remove the cross_patch_attention. Signed-off-by: Hexu Zhao --- point_transformer_v3/.gitignore | 1 + point_transformer_v3/model.py | 101 +------------------------------- 2 files changed, 2 insertions(+), 100 deletions(-) diff --git a/point_transformer_v3/.gitignore b/point_transformer_v3/.gitignore index 0099b23..5b5438c 100644 --- a/point_transformer_v3/.gitignore +++ b/point_transformer_v3/.gitignore @@ -8,3 +8,4 @@ fvdb-test-data !requirements.txt /tests /data +/__pycache__/ diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index 3f21f02..7d326f7 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -233,8 +233,6 @@ def __init__( num_heads: int, proj_drop: float = 0.0, patch_size: int = 0, - cross_patch_attention: bool = False, - cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, order_type: str = "vdb", ): @@ -244,8 +242,6 @@ def __init__( num_heads (int): Number of attention heads in each block. proj_drop (float): Dropout rate for MLP layers. patch_size (int): Patch size for patch attention. - cross_patch_attention (bool): Whether to use cross-patch attention. - cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). order_type (str): The type of order of the points, "vdb" or "z". """ @@ -260,8 +256,6 @@ def __init__( self.drop = torch.nn.Dropout(proj_drop) self.patch_size = patch_size self.order_type = order_type - self.cross_patch_attention = cross_patch_attention - self.cross_patch_pooling = cross_patch_pooling # "mean" or "max" # Sliding window attention parameter self.sliding_window_attention = sliding_window_attention @@ -290,8 +284,6 @@ def forward(self, grid, feats): qkv = self.qkv(feats_j) # (num_voxels, 3 * hidden_size) - # TODO: only keep the sliding window attention and the default window attention. 
- if self.sliding_window_attention and self.patch_size > 0: # Perform sliding window attention per-grid using flash attention num_voxels = feats_j.shape[0] @@ -350,71 +342,7 @@ def forward(self, grid, feats): qkv.half(), cu_seqlens, max_seqlen=self.patch_size, dropout_p=0.0, softmax_scale=1.0 ).reshape(num_voxels, self.hidden_size) - if self.cross_patch_attention: - # Apply cross-patch attention per-grid and add to feats_out_j - add_buf = torch.zeros_like(feats_out_j) - for b in range(offsets.numel() - 1): - start = int(offsets[b].item()) - end = int(offsets[b + 1].item()) - Li = end - start - if Li <= 0: - continue - num_complete_patches = Li // self.patch_size - remaining_voxels = Li % self.patch_size - - qkv_b = qkv[start:end] # (Li, 3, H, D) - if num_complete_patches > 0: - complete_voxels = num_complete_patches * self.patch_size - qkv_complete = qkv_b[:complete_voxels] - qkv_patches = qkv_complete.view(num_complete_patches, self.patch_size, 3, H, D) - if self.cross_patch_pooling == "mean": - qkv_pooled = qkv_patches.mean(dim=1) - elif self.cross_patch_pooling == "max": - qkv_pooled = qkv_patches.max(dim=1)[0] - else: - raise ValueError(f"Unsupported pooling method: {self.cross_patch_pooling}") - else: - qkv_pooled = None - complete_voxels = 0 - - num_total_patches = num_complete_patches - if remaining_voxels > 0: - qkv_remaining = qkv_b[complete_voxels:] - if self.cross_patch_pooling == "mean": - qkv_remaining_pooled = qkv_remaining.mean(dim=0, keepdim=True) - else: - qkv_remaining_pooled = qkv_remaining.max(dim=0, keepdim=True)[0] - qkv_pooled = ( - qkv_remaining_pooled - if qkv_pooled is None - else torch.cat([qkv_pooled, qkv_remaining_pooled], dim=0) - ) - num_total_patches += 1 - - if num_total_patches > 0: - # qkv_pooled must be defined here because num_total_patches > 0 - assert qkv_pooled is not None - cross_attn_out = flash_attn.flash_attn_qkvpacked_func( - qkv_pooled.unsqueeze(0).half(), dropout_p=0.0, softmax_scale=1.0 - ).reshape(num_total_patches, self.hidden_size) - - cross_attn_all_b = torch.zeros( - (Li, self.hidden_size), device=qkv.device, dtype=cross_attn_out.dtype - ) - if num_complete_patches > 0: - cross_attn_complete = cross_attn_out[:num_complete_patches] - cross_attn_expanded = cross_attn_complete.unsqueeze(1).expand(-1, self.patch_size, -1) - cross_attn_flat = cross_attn_expanded.reshape(complete_voxels, self.hidden_size) - cross_attn_all_b[:complete_voxels] = cross_attn_flat - if remaining_voxels > 0: - cross_attn_all_b[complete_voxels:] = ( - cross_attn_out[-1].unsqueeze(0).expand(remaining_voxels, -1) - ) - - add_buf[start:end] = cross_attn_all_b.to(add_buf.dtype) - - # Apply cross-patch addition across the whole batch - feats_out_j = (feats_out_j + add_buf.to(feats_out_j.dtype)).to(feats_j.dtype) + feats_out_j = feats_out_j.to(feats_j.dtype) else: feats_out_j = qkv[:, : self.hidden_size].contiguous() @@ -491,8 +419,6 @@ def __init__( proj_drop: float = 0.0, patch_size: int = 0, no_conv_in_cpe: bool = False, - cross_patch_attention: bool = False, - cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, order_type: str = "vdb", ): @@ -504,8 +430,6 @@ def __init__( proj_drop (float): Dropout rate for MLP layers. patch_size (int): Patch size for patch attention. no_conv_in_cpe (bool): Whether to disable convolution in CPE. - cross_patch_attention (bool): Whether to use cross-patch attention. - cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). 
sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). order_type (str): The type of order of the points: "vdb", "z", "z-trans", "hilbert", "hilbert-trans". """ @@ -518,8 +442,6 @@ def __init__( num_heads, proj_drop, patch_size, - cross_patch_attention, - cross_patch_pooling, sliding_window_attention, order_type, ) @@ -567,8 +489,6 @@ def __init__( proj_drop: float = 0.0, patch_size: int = 0, no_conv_in_cpe: bool = False, - cross_patch_attention: bool = False, - cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, order_type: str = "vdb", ): @@ -581,8 +501,6 @@ def __init__( proj_drop (float): Dropout rate for MLP layers. patch_size (int): Patch size for patch attention. no_conv_in_cpe (bool): Whether to disable convolution in CPE. - cross_patch_attention (bool): Whether to use cross-patch attention. - cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). order_type (str): The type of order of the points, "vdb" or "z". """ @@ -597,8 +515,6 @@ def __init__( proj_drop, patch_size, no_conv_in_cpe, - cross_patch_attention, - cross_patch_pooling, sliding_window_attention, order_type, ) @@ -637,8 +553,6 @@ def __init__( enable_batch_norm: bool = False, embedding_mode: str = "linear", no_conv_in_cpe: bool = False, - cross_patch_attention: bool = False, - cross_patch_pooling: str = "mean", sliding_window_attention: bool = False, order_type: Union[str, List[str]] = "vdb", ) -> None: @@ -661,8 +575,6 @@ def __init__( enable_batch_norm (bool): Whether to use batch normalization for the embedding, down pooling, and up pooling. embedding_mode (bool): the mode for the embedding layer, "linear" or "conv3x3", "conv5x5". no_conv_in_cpe (bool): Whether to disable convolution in CPE. - cross_patch_attention (bool): Whether to use cross-patch attention. - cross_patch_pooling (str): Pooling method for cross-patch attention ("mean" or "max"). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). order_type (Union[str, List[str]]): The type of order of the points. Can be a single string ("vdb", "z", "z-trans", "hilbert", "hilbert-trans") for all layers, or a list of strings for different layers. Each encoder and decoder stage will use @@ -673,8 +585,6 @@ def __init__( self.drop_path = drop_path self.proj_drop = proj_drop self.no_conv_in_cpe = no_conv_in_cpe - self.cross_patch_attention = cross_patch_attention - self.cross_patch_pooling = cross_patch_pooling self.sliding_window_attention = sliding_window_attention # Handle order_type: convert to list for uniform processing @@ -689,11 +599,6 @@ def __init__( else: self.norm_layer = partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01) - # sliding_window_attention and cross_patch_attention should not be used together. - assert not ( - sliding_window_attention and cross_patch_attention - ), "sliding_window_attention and cross_patch_attention should not be used together." 
- if len(enc_channels) > 0: self.embedding = PTV3_Embedding( input_dim, enc_channels[0], norm_layer_module=self.norm_layer, embedding_mode=embedding_mode @@ -726,8 +631,6 @@ def __init__( proj_drop, patch_size, no_conv_in_cpe, - cross_patch_attention, - cross_patch_pooling, sliding_window_attention, stage_order_type, ) @@ -771,8 +674,6 @@ def __init__( proj_drop, patch_size, no_conv_in_cpe, - cross_patch_attention, - cross_patch_pooling, sliding_window_attention, stage_order_type, ) From 4b203f0f92260ac3635dca1311385d2d43a64f58 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Tue, 21 Oct 2025 18:26:50 -0700 Subject: [PATCH 06/10] Clean the code. Signed-off-by: Hexu Zhao --- point_transformer_v3/model.py | 36 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index 7d326f7..349224d 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -210,7 +210,7 @@ def __init__(self, hidden_size: int, proj_drop: float = 0.0): self.fc1 = torch.nn.Linear(hidden_size, hidden_size * 4) self.act = torch.nn.GELU() self.fc2 = torch.nn.Linear(hidden_size * 4, hidden_size) - self.drop = torch.nn.Dropout(proj_drop) # simplified setting: no dropout now. + self.drop = torch.nn.Dropout(proj_drop) def forward(self, grid, feats): nvtx.range_push("PTV3_MLP") @@ -301,7 +301,9 @@ def forward(self, grid, feats): window_size = (self.patch_size // 2, self.patch_size // 2) out_b = flash_attn.flash_attn_qkvpacked_func( qkv_b.half(), dropout_p=0.0, softmax_scale=1.0, window_size=window_size - ).reshape(Li, self.hidden_size) + ).reshape( + Li, self.hidden_size + ) # dtype: float16 outputs.append(out_b) if len(outputs) == 0: feats_out_j = torch.empty_like(qkv[:, : self.hidden_size]) @@ -340,15 +342,14 @@ def forward(self, grid, feats): feats_out_j = flash_attn.flash_attn_varlen_qkvpacked_func( qkv.half(), cu_seqlens, max_seqlen=self.patch_size, dropout_p=0.0, softmax_scale=1.0 - ).reshape(num_voxels, self.hidden_size) + ).reshape( + num_voxels, self.hidden_size + ) # dtype: float16 feats_out_j = feats_out_j.to(feats_j.dtype) else: feats_out_j = qkv[:, : self.hidden_size].contiguous() - # Ensure dtype matches original features before linear projection - feats_out_j = feats_out_j.to(feats_j.dtype) - if self.order_type != "vdb": perm_reverse = torch.empty_like(perm) perm_reverse[perm] = torch.arange(perm.shape[0], device=perm.device) # [num_voxels] @@ -374,16 +375,15 @@ def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False): super().__init__() self.hidden_size = hidden_size self.no_conv_in_cpe = no_conv_in_cpe - # Wrap components in Sequential to match parameter naming convention self.cpe = torch.nn.ModuleList( [ ( fvdb.nn.SparseConv3d(hidden_size, hidden_size, kernel_size=3, stride=1) if not no_conv_in_cpe else torch.nn.Identity() - ), # cpe.0 - torch.nn.Linear(hidden_size, hidden_size), # cpe.1 - torch.nn.LayerNorm(hidden_size), # cpe.2 + ), + torch.nn.Linear(hidden_size, hidden_size), + torch.nn.LayerNorm(hidden_size), ] ) @@ -436,7 +436,7 @@ def __init__( super().__init__() self.cpe = PTV3_CPE(hidden_size, no_conv_in_cpe) - self.norm1 = torch.nn.Sequential(torch.nn.LayerNorm(hidden_size)) # norm1.0 + self.norm1 = torch.nn.LayerNorm(hidden_size) self.attn = PTV3_Attention( hidden_size, num_heads, @@ -445,7 +445,7 @@ def __init__( sliding_window_attention, order_type, ) - self.norm2 = torch.nn.Sequential(torch.nn.LayerNorm(hidden_size)) # norm2.0 + self.norm2 = 
torch.nn.LayerNorm(hidden_size) self.order_type = order_type self.mlp = PTV3_MLP(hidden_size, proj_drop) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity() @@ -599,12 +599,9 @@ def __init__( else: self.norm_layer = partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01) - if len(enc_channels) > 0: - self.embedding = PTV3_Embedding( - input_dim, enc_channels[0], norm_layer_module=self.norm_layer, embedding_mode=embedding_mode - ) - else: - self.embedding = None + self.embedding = PTV3_Embedding( + input_dim, enc_channels[0], norm_layer_module=self.norm_layer, embedding_mode=embedding_mode + ) self.num_stages = len(enc_depths) if self.num_stages > 0: @@ -682,9 +679,6 @@ def __init__( def forward(self, grid, feats): nvtx.range_push("PTV3_Forward") - if self.embedding is None: - return grid, feats - grid, feats = self.embedding(grid, feats) layer_id = 0 From dd60e584523d9f06c0748c35e907fd2890287af5 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Thu, 23 Oct 2025 02:59:46 +0000 Subject: [PATCH 07/10] move the permutation operation from fvdb.grid to point transformer v3. Signed-off-by: Hexu Zhao --- point_transformer_v3/model.py | 77 +++++++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 4 deletions(-) diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index 349224d..60aa1cd 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -260,15 +260,84 @@ def __init__( # Sliding window attention parameter self.sliding_window_attention = sliding_window_attention + def _compute_permutation(self, grid, curve_codes): + """ + Get permutation indices to sort voxels by space-filling curve order. + + Takes pre-computed space-filling curve codes (e.g., from morton(), hilbert(), etc.) + and returns the permutation indices that would sort voxels according to those codes. + This is useful for spatially coherent data access patterns and cache optimization. + + Args: + grid: The grid batch containing voxel information. + curve_codes (JaggedTensor): Space-filling curve codes for each voxel. + Shape: `[num_grids, -1, 1]`. Typically obtained from morton(), morton_zyx(), + hilbert(), or hilbert_zyx() methods. + + Returns: + JaggedTensor: A JaggedTensor of shape `[num_grids, -1, 1]` containing + the permutation indices. Use these indices to reorder voxel data for spatial coherence. + """ + # Get the curve codes as a flat tensor + curve_data = curve_codes.jdata.squeeze(-1) # Shape: [total_voxels] + + # Create output tensor for permutation indices + permutation_indices = torch.empty_like(curve_data, dtype=torch.long) + + # Sort curve codes and get permutation indices for each grid + offset = 0 + for grid_idx in range(grid.grid_count): + num_voxels = grid.num_voxels_at(grid_idx) + if num_voxels == 0: + continue + + # Extract curve codes for this grid + grid_curve_codes = curve_data[offset : offset + num_voxels] + + # Sort and get indices + _, indices = torch.sort(grid_curve_codes, dim=0) + + # Store indices with offset + permutation_indices[offset : offset + num_voxels] = indices + offset + + offset += num_voxels + + # Return as JaggedTensor with the same structure as the input + return grid.jagged_like(permutation_indices.unsqueeze(-1)) + + def _permutation_morton(self, grid): + """ + Return permutation indices to sort voxels by Morton curve order. 
+ """ + return self._compute_permutation(grid, grid.morton()) + + def _permutation_morton_zyx(self, grid): + """ + Return permutation indices to sort voxels by transposed Morton curve order. + """ + return self._compute_permutation(grid, grid.morton_zyx()) + + def _permutation_hilbert(self, grid): + """ + Return permutation indices to sort voxels by Hilbert curve order. + """ + return self._compute_permutation(grid, grid.hilbert()) + + def _permutation_hilbert_zyx(self, grid): + """ + Return permutation indices to sort voxels by transposed Hilbert curve order. + """ + return self._compute_permutation(grid, grid.hilbert_zyx()) + def _permute(self, grid, order_type): if order_type == "z": - return grid.permutation_morton() + return self._permutation_morton(grid) elif order_type == "z-trans": - return grid.permutation_morton_zyx() + return self._permutation_morton_zyx(grid) elif order_type == "hilbert": - return grid.permutation_hilbert() + return self._permutation_hilbert(grid) elif order_type == "hilbert-trans": - return grid.permutation_hilbert_zyx() + return self._permutation_hilbert_zyx(grid) else: raise ValueError(f"Unsupported order type: {order_type}") From af5e73e344d48a61df286ca2cb5e028a3511cb7c Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Fri, 31 Oct 2025 17:50:36 +0000 Subject: [PATCH 08/10] update to new fvdb api. TODO: fix bugs. Signed-off-by: Hexu Zhao --- point_transformer_v3/minimal_inference.py | 8 +- point_transformer_v3/model.py | 113 ++++++++++++++-------- point_transformer_v3/requirements.txt | 2 +- 3 files changed, 82 insertions(+), 41 deletions(-) diff --git a/point_transformer_v3/minimal_inference.py b/point_transformer_v3/minimal_inference.py index 8946e28..65a8db3 100644 --- a/point_transformer_v3/minimal_inference.py +++ b/point_transformer_v3/minimal_inference.py @@ -122,6 +122,8 @@ def create_ptv3_model(args, device, num_classes): patch_size=args.patch_size, proj_drop=0.0, drop_path=0.3, + # no_conv_in_cpe=True, + # embedding_mode="linear", ).to(device) return model @@ -150,8 +152,12 @@ def prepare_batched_inputs_from_scannet_points(batch_samples, voxel_size=0.1, de color_jagged = fvdb.JaggedTensor(colors_list) color_vdb_order = grid.inject_from_ijk(coords_jagged, color_jagged) jfeats = color_vdb_order.jdata - jfeats = fvdb.jcat([grid.ijk.float(), jfeats], dim=1) + # jfeats = fvdb.jcat([grid.ijk.float(), jfeats], dim=1) + jfeats = fvdb.jcat([grid.ijk.float(), color_vdb_order], dim=1) return grid, jfeats + # import pdb; pdb.set_trace() + # jfeats = torch.cat([grid.ijk.float(), jfeats], dim=1) + # return grid, fvdb.JaggedTensor(jfeats) def main(): diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index 60aa1cd..70cf24c 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -41,6 +41,7 @@ def __init__( embed_channels, norm_layer_module: torch.nn.Module = torch.nn.LayerNorm, embedding_mode: str = "linear", + shared_plan_cache: Dict = None, ): """ Args: @@ -48,9 +49,11 @@ def __init__( embed_channels (int): Number of channels in the output features. norm_layer_module (torch.nn.Module): Normalization layer module. embedding_mode (str): The type of embedding layer, "linear" or "conv3x3", "conv5x5". + shared_plan_cache (Dict): Shared cache for ConvolutionPlans across all layers. 
""" super().__init__() self.embedding_mode = embedding_mode + self.shared_plan_cache = shared_plan_cache if embedding_mode == "linear": self.embed = torch.nn.Linear(in_channels, embed_channels) @@ -80,34 +83,43 @@ def __init__( self.norm_layer = norm_layer_module(embed_channels) self.act_layer = torch.nn.GELU() + def _get_plan(self, grid, kernel_size, stride): + """Get or create a ConvolutionPlan from shared cache.""" + cache_key = (grid.address, kernel_size, stride) + if cache_key not in self.shared_plan_cache: + self.shared_plan_cache[cache_key] = fvdb.ConvolutionPlan.from_grid_batch( + kernel_size=kernel_size, + stride=stride, + source_grid=grid, + target_grid=grid + ) + return self.shared_plan_cache[cache_key] + def forward(self, grid, feats): nvtx.range_push("PTV3_Embedding") - # Initialize kernel map (kmap) for sparse convolution operations - # kmap tracks the mapping between input and output features during sparse convolutions - if not hasattr(grid, "kmap"): - grid.kmap = None - if self.embedding_mode == "linear": jfeats = feats.jdata jfeats = self.embed(jfeats) elif self.embedding_mode == "conv3x3": - # Apply 3x3 sparse convolution while maintaining kernel mapping - # Note: Bias is intentionally disabled to maintain consistency with standard transformer architectures - grid, feats, out_kmap = self.embed_conv3x3_1._dispatch_conv(feats, grid, grid.kmap, grid) - grid.kmap = out_kmap + # Apply 3x3 sparse convolution using shared ConvolutionPlan cache + plan = self._get_plan(grid, kernel_size=3, stride=1) + feats = self.embed_conv3x3_1(feats, plan) jfeats = feats.jdata elif self.embedding_mode == "conv5x5": - grid, feats, out_kmap = self.embed_conv3x3_1._dispatch_conv(feats, grid, grid.kmap, grid) - grid.kmap = out_kmap - grid, feats, out_kmap = self.embed_conv3x3_2._dispatch_conv(feats, grid, grid.kmap, grid) - grid.kmap = out_kmap + # First 3x3 convolution + plan1 = self._get_plan(grid, kernel_size=3, stride=1) + feats = self.embed_conv3x3_1(feats, plan1) + + # Second 3x3 convolution (same grid since stride=1, in-place) + plan2 = self._get_plan(grid, kernel_size=3, stride=1) + feats = self.embed_conv3x3_2(feats, plan2) jfeats = feats.jdata jfeats = self.norm_layer(jfeats) jfeats = self.act_layer(jfeats) - feats = feats.jagged_like(jfeats) + feats = grid.jagged_like(jfeats) nvtx.range_pop() return grid, feats @@ -135,15 +147,14 @@ def __init__( def forward(self, grid, feats): nvtx.range_push("PTV3_Pooling") feats_j = self.proj(feats.jdata) - feats = feats.jagged_like(feats_j) + feats = grid.jagged_like(feats_j) ds_feature, ds_grid = grid.max_pool(self.kernel_size, feats, stride=self.kernel_size, coarse_grid=None) ds_feature_j = ds_feature.jdata ds_feature_j = self.norm_layer(ds_feature_j) ds_feature_j = self.act_layer(ds_feature_j) - ds_feature = ds_feature.jagged_like(ds_feature_j) + ds_feature = ds_grid.jagged_like(ds_feature_j) nvtx.range_pop() - ds_grid.kmap = None return ds_grid, ds_feature @@ -188,13 +199,10 @@ def forward(self, grid, feats, last_grid, last_feats): last_feats_j = self.norm_skip(last_feats_j) last_feats_j = self.act_layer_skip(last_feats_j) - feats, _ = grid.subdivide(self.kernel_size, grid.jagged_like(feats_j), fine_grid=last_grid) + feats, _ = grid.refine(self.kernel_size, grid.jagged_like(feats_j), fine_grid=last_grid) feats_j = feats.jdata new_feats_j = last_feats_j + feats_j - last_grid.kmap = ( - None # Because of the pooling operation, the previous kmap for convolution is not valid anymore. 
- ) return last_grid, last_grid.jagged_like(new_feats_j) @@ -221,7 +229,7 @@ def forward(self, grid, feats): feats_j = self.drop(feats_j) feats_j = self.fc2(feats_j) feats_j = self.drop(feats_j) - feats = feats.jagged_like(feats_j) + feats = grid.jagged_like(feats_j) nvtx.range_pop() return grid, feats @@ -435,15 +443,17 @@ def forward(self, grid, feats): class PTV3_CPE(torch.nn.Module): - def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False): + def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False, shared_plan_cache: Dict = None): """ Args: hidden_size (int): Number of channels in the input features. no_conv_in_cpe (bool): Whether to disable convolution in CPE. + shared_plan_cache (Dict): Shared cache for ConvolutionPlans across all layers. """ super().__init__() self.hidden_size = hidden_size self.no_conv_in_cpe = no_conv_in_cpe + self.shared_plan_cache = shared_plan_cache self.cpe = torch.nn.ModuleList( [ ( @@ -456,17 +466,26 @@ def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False): ] ) + def _get_plan(self, grid, kernel_size, stride): + """Get or create a ConvolutionPlan from shared cache.""" + cache_key = (grid.address, kernel_size, stride) + if cache_key not in self.shared_plan_cache: + self.shared_plan_cache[cache_key] = fvdb.ConvolutionPlan.from_grid_batch( + kernel_size=kernel_size, + stride=stride, + source_grid=grid, + target_grid=grid + ) + return self.shared_plan_cache[cache_key] + def forward(self, grid, feats): nvtx.range_push("PTV3_CPE") - if not hasattr(grid, "kmap"): - grid.kmap = None - if not self.no_conv_in_cpe: - grid, out_feature, out_kmap = self.cpe[0]._dispatch_conv(feats, grid, grid.kmap, grid) - grid.kmap = out_kmap # update the kmap - if self.cpe[0].bias is not None: - out_feature.jdata = out_feature.jdata + self.cpe[0].bias + # Apply 3x3 sparse convolution using shared ConvolutionPlan cache + plan = self._get_plan(grid, kernel_size=3, stride=1) + out_feature = self.cpe[0](feats, plan) + # Note: bias is already handled inside SparseConv3d.forward() else: out_feature = feats @@ -490,6 +509,7 @@ def __init__( no_conv_in_cpe: bool = False, sliding_window_attention: bool = False, order_type: str = "vdb", + shared_plan_cache: Dict = None, ): """ Args: @@ -501,10 +521,11 @@ def __init__( no_conv_in_cpe (bool): Whether to disable convolution in CPE. sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). order_type (str): The type of order of the points: "vdb", "z", "z-trans", "hilbert", "hilbert-trans". + shared_plan_cache (Dict): Shared cache for ConvolutionPlans across all layers. """ super().__init__() - self.cpe = PTV3_CPE(hidden_size, no_conv_in_cpe) + self.cpe = PTV3_CPE(hidden_size, no_conv_in_cpe, shared_plan_cache) self.norm1 = torch.nn.LayerNorm(hidden_size) self.attn = PTV3_Attention( hidden_size, @@ -522,27 +543,27 @@ def __init__( def forward(self, grid, feats): nvtx.range_push("PTV3_Block") grid, feats_out = self.cpe(grid, feats) - feats = feats.jagged_like(feats.jdata + feats_out.jdata) + feats = grid.jagged_like(feats.jdata + feats_out.jdata) short_cut = feats.jdata - feats = feats.jagged_like(self.norm1(feats.jdata)) + feats = grid.jagged_like(self.norm1(feats.jdata)) grid, feats_out = self.attn(grid, feats) - feats_out = feats.jagged_like( + feats_out = grid.jagged_like( self.drop_path(feats_out.jdata) ) # This drop_path is applied to each point independently. 
- feats = feats.jagged_like(short_cut + feats_out.jdata) + feats = grid.jagged_like(short_cut + feats_out.jdata) short_cut = feats.jdata - feats = feats.jagged_like(self.norm2(feats.jdata)) + feats = grid.jagged_like(self.norm2(feats.jdata)) grid, feats_out = self.mlp(grid, feats) - feats_out = feats.jagged_like( + feats_out = grid.jagged_like( self.drop_path(feats_out.jdata) ) # This drop_path is applied to each point independently. - feats = feats.jagged_like(short_cut + feats_out.jdata) + feats = grid.jagged_like(short_cut + feats_out.jdata) nvtx.range_pop() return grid, feats @@ -560,6 +581,7 @@ def __init__( no_conv_in_cpe: bool = False, sliding_window_attention: bool = False, order_type: str = "vdb", + shared_plan_cache: Dict = None, ): """ Args: @@ -572,6 +594,7 @@ def __init__( no_conv_in_cpe (bool): Whether to disable convolution in CPE. sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). order_type (str): The type of order of the points, "vdb" or "z". + shared_plan_cache (Dict): Shared cache for ConvolutionPlans across all layers. """ super().__init__() self.depth = depth @@ -586,6 +609,7 @@ def __init__( no_conv_in_cpe, sliding_window_attention, order_type, + shared_plan_cache, ) for i in range(depth) ] @@ -668,8 +692,13 @@ def __init__( else: self.norm_layer = partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01) + # Shared ConvolutionPlan cache across all layers to avoid redundant computation. + # Cache is cleared at the end of each forward pass to prevent OOM. + self.shared_plan_cache = {} + self.embedding = PTV3_Embedding( - input_dim, enc_channels[0], norm_layer_module=self.norm_layer, embedding_mode=embedding_mode + input_dim, enc_channels[0], norm_layer_module=self.norm_layer, + embedding_mode=embedding_mode, shared_plan_cache=self.shared_plan_cache ) self.num_stages = len(enc_depths) @@ -699,6 +728,7 @@ def __init__( no_conv_in_cpe, sliding_window_attention, stage_order_type, + self.shared_plan_cache, ) ) @@ -742,6 +772,7 @@ def __init__( no_conv_in_cpe, sliding_window_attention, stage_order_type, + self.shared_plan_cache, ) ) @@ -778,5 +809,9 @@ def forward(self, grid, feats): nvtx.range_pop() layer_id += 1 + # Clear cache after forward pass to prevent OOM between batches + # Plans are shared across layers during this forward pass, but won't be needed for next batch + self.shared_plan_cache.clear() + nvtx.range_pop() return grid, feats diff --git a/point_transformer_v3/requirements.txt b/point_transformer_v3/requirements.txt index 91b2155..2c18427 100644 --- a/point_transformer_v3/requirements.txt +++ b/point_transformer_v3/requirements.txt @@ -1,2 +1,2 @@ -flash_attn==2.7.2.post1 +flash-attn==2.7.4.post1 timm From cc5ca3fd939e9075e49b2ea8fad503a9e2fc1757 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Fri, 7 Nov 2025 19:43:41 +0000 Subject: [PATCH 09/10] Fix two previous gaps from pointcept: (1) Different orders shuffle for different point cloud inputs. 
(2) add attention softmax scaling Signed-off-by: Hexu Zhao --- point_transformer_v3/minimal_inference.py | 1 + point_transformer_v3/model.py | 147 ++++++++++++++++------ 2 files changed, 107 insertions(+), 41 deletions(-) diff --git a/point_transformer_v3/minimal_inference.py b/point_transformer_v3/minimal_inference.py index 65a8db3..216da66 100644 --- a/point_transformer_v3/minimal_inference.py +++ b/point_transformer_v3/minimal_inference.py @@ -122,6 +122,7 @@ def create_ptv3_model(args, device, num_classes): patch_size=args.patch_size, proj_drop=0.0, drop_path=0.3, + order_type=("z", "z-trans", "hilbert", "hilbert-trans"), # no_conv_in_cpe=True, # embedding_mode="linear", ).to(device) diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index 70cf24c..bf8a2c5 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -222,7 +222,7 @@ def __init__(self, hidden_size: int, proj_drop: float = 0.0): def forward(self, grid, feats): nvtx.range_push("PTV3_MLP") - feats_j = feats.jdata + feats_j = feats.jdata # TODO: deprecate the .jdata usage. feats_j = self.fc1(feats_j) feats_j = self.act(feats_j) @@ -241,8 +241,10 @@ def __init__( num_heads: int, proj_drop: float = 0.0, patch_size: int = 0, + qk_scale: float = None, sliding_window_attention: bool = False, - order_type: str = "vdb", + order_index: int = 0, + order_types: tuple = ("vdb",), ): """ Args: @@ -250,8 +252,10 @@ def __init__( num_heads (int): Number of attention heads in each block. proj_drop (float): Dropout rate for MLP layers. patch_size (int): Patch size for patch attention. + qk_scale (float): Scale factor for query-key dot product. If None, uses 1/sqrt(head_dim). sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). - order_type (str): The type of order of the points, "vdb" or "z". + order_index (int): Index into order_types to select which order to use for this block. + order_types (tuple): Tuple of order type strings (e.g., ("z", "z-trans")). 
""" super().__init__() self.hidden_size = hidden_size @@ -259,11 +263,15 @@ def __init__( self.head_dim = hidden_size // num_heads assert self.head_dim * num_heads == hidden_size, "hidden_size must be divisible by num_heads" + self.scale = qk_scale or (self.head_dim) ** -0.5 self.qkv = torch.nn.Linear(hidden_size, hidden_size * 3) # Combined QKV projection self.proj = torch.nn.Linear(hidden_size, hidden_size) self.drop = torch.nn.Dropout(proj_drop) self.patch_size = patch_size - self.order_type = order_type + self.order_index = order_index + self.order_types = order_types + + # TODO: Add attention dropout # Sliding window attention parameter self.sliding_window_attention = sliding_window_attention @@ -353,12 +361,21 @@ def forward(self, grid, feats): nvtx.range_push("PTV3_Attention") feats_j = feats.jdata - if self.order_type != "vdb": - perm = self._permute(grid, self.order_type).jdata.squeeze(-1) # [num_voxels] + # Get the shuffled order from grid metadata if available, otherwise use default order_types + # This allows for order shuffling per forward pass (matching reference implementation) + active_order_types = grid._shuffled_order + + # Get the order type for this block using the order index + order_type = active_order_types[self.order_index % len(active_order_types)] + + if order_type != "vdb": + perm = self._permute(grid, order_type).jdata.squeeze(-1) # [num_voxels] # Use torch.gather for permutation: expand perm to match feats_j dimensions perm_expanded = perm.unsqueeze(-1).expand(-1, feats_j.shape[-1]) # [num_voxels, hidden_size] feats_j = torch.gather(feats_j, 0, perm_expanded) + # import pdb; pdb.set_trace() + qkv = self.qkv(feats_j) # (num_voxels, 3 * hidden_size) if self.sliding_window_attention and self.patch_size > 0: @@ -377,7 +394,10 @@ def forward(self, grid, feats): qkv_b = qkv[start:end].view(1, Li, 3, H, D) window_size = (self.patch_size // 2, self.patch_size // 2) out_b = flash_attn.flash_attn_qkvpacked_func( - qkv_b.half(), dropout_p=0.0, softmax_scale=1.0, window_size=window_size + qkv_b.half(), + dropout_p=0.0, + softmax_scale=self.scale, + window_size=window_size ).reshape( Li, self.hidden_size ) # dtype: float16 @@ -418,7 +438,11 @@ def forward(self, grid, feats): cu_seqlens[1:] = torch.as_tensor(lengths, device=qkv.device, dtype=torch.int32).cumsum(dim=0) feats_out_j = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv.half(), cu_seqlens, max_seqlen=self.patch_size, dropout_p=0.0, softmax_scale=1.0 + qkv.half(), + cu_seqlens, + max_seqlen=self.patch_size, + dropout_p=0.0, # TODO: implement attention dropout in the future. By default, it is 0. + softmax_scale=self.scale ).reshape( num_voxels, self.hidden_size ) # dtype: float16 @@ -427,7 +451,7 @@ def forward(self, grid, feats): else: feats_out_j = qkv[:, : self.hidden_size].contiguous() - if self.order_type != "vdb": + if order_type != "vdb": perm_reverse = torch.empty_like(perm) perm_reverse[perm] = torch.arange(perm.shape[0], device=perm.device) # [num_voxels] perm_reverse_expanded = perm_reverse.unsqueeze(-1).expand( @@ -457,7 +481,7 @@ def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False, shared_plan_c self.cpe = torch.nn.ModuleList( [ ( - fvdb.nn.SparseConv3d(hidden_size, hidden_size, kernel_size=3, stride=1) + fvdb.nn.SparseConv3d(hidden_size, hidden_size, kernel_size=3, stride=1) # by default, bias is True. 
if not no_conv_in_cpe else torch.nn.Identity() ), @@ -506,9 +530,11 @@ def __init__( drop_path: float, proj_drop: float = 0.0, patch_size: int = 0, + qk_scale: float = None, no_conv_in_cpe: bool = False, sliding_window_attention: bool = False, - order_type: str = "vdb", + order_index: int = 0, + order_types: tuple = ("vdb",), shared_plan_cache: Dict = None, ): """ @@ -518,9 +544,11 @@ def __init__( drop_path (float): Drop path rate for regularization. proj_drop (float): Dropout rate for MLP layers. patch_size (int): Patch size for patch attention. + qk_scale (float): Scale factor for query-key dot product. If None, uses 1/sqrt(head_dim). no_conv_in_cpe (bool): Whether to disable convolution in CPE. sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). - order_type (str): The type of order of the points: "vdb", "z", "z-trans", "hilbert", "hilbert-trans". + order_index (int): Index into order_types to select which order to use for this block. + order_types (tuple): Tuple of order type strings (e.g., ("z", "z-trans")). shared_plan_cache (Dict): Shared cache for ConvolutionPlans across all layers. """ super().__init__() @@ -532,18 +560,20 @@ def __init__( num_heads, proj_drop, patch_size, + qk_scale, sliding_window_attention, - order_type, + order_index, + order_types, ) self.norm2 = torch.nn.LayerNorm(hidden_size) - self.order_type = order_type + self.order_index = order_index self.mlp = PTV3_MLP(hidden_size, proj_drop) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity() def forward(self, grid, feats): nvtx.range_push("PTV3_Block") grid, feats_out = self.cpe(grid, feats) - feats = grid.jagged_like(feats.jdata + feats_out.jdata) + feats = grid.jagged_like(feats.jdata + feats_out.jdata) # Is this a potential issue? short_cut = feats.jdata feats = grid.jagged_like(self.norm1(feats.jdata)) @@ -578,9 +608,10 @@ def __init__( drop_path, # drop_path is a list of drop path rates for each block. proj_drop: float = 0.0, patch_size: int = 0, + qk_scale: float = None, no_conv_in_cpe: bool = False, sliding_window_attention: bool = False, - order_type: str = "vdb", + order_types: tuple = ("vdb",), shared_plan_cache: Dict = None, ): """ @@ -591,9 +622,10 @@ def __init__( drop_path (list): Drop path rates for each block. proj_drop (float): Dropout rate for MLP layers. patch_size (int): Patch size for patch attention. + qk_scale (float): Scale factor for query-key dot product. If None, uses 1/sqrt(head_dim). no_conv_in_cpe (bool): Whether to disable convolution in CPE. sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). - order_type (str): The type of order of the points, "vdb" or "z". + order_types (tuple): Tuple of order type strings (e.g., ("z", "z-trans")). shared_plan_cache (Dict): Shared cache for ConvolutionPlans across all layers. 
""" super().__init__() @@ -606,15 +638,17 @@ def __init__( drop_path[i], proj_drop, patch_size, + qk_scale, no_conv_in_cpe, sliding_window_attention, - order_type, + i % len(order_types), # order_index cycles through available order types + order_types, shared_plan_cache, ) for i in range(depth) ] ) - self.order_type = order_type + self.order_types = order_types def forward(self, grid, feats): for block in self.blocks: @@ -643,11 +677,13 @@ def __init__( patch_size: int = 0, drop_path: float = 0.3, proj_drop: float = 0.0, + qk_scale: float = None, enable_batch_norm: bool = False, embedding_mode: str = "linear", no_conv_in_cpe: bool = False, sliding_window_attention: bool = False, - order_type: Union[str, List[str]] = "vdb", + order_type: Union[str, tuple] = ("z", "z-trans"), + shuffle_orders: bool = True, ) -> None: """ ptv3 for 3D point cloud segmentation. @@ -665,27 +701,26 @@ def __init__( patch_size (int): Patch size for patch attention. drop_path (float): Drop path rate for regularization. proj_drop (float): Dropout rate for MLP layers. + qk_scale (float): Scale factor for query-key dot product. If None, uses 1/sqrt(head_dim). enable_batch_norm (bool): Whether to use batch normalization for the embedding, down pooling, and up pooling. embedding_mode (bool): the mode for the embedding layer, "linear" or "conv3x3", "conv5x5". no_conv_in_cpe (bool): Whether to disable convolution in CPE. sliding_window_attention (bool): Whether to use sliding window attention (uses patch_size as window size). - order_type (Union[str, List[str]]): The type of order of the points. Can be a single string ("vdb", "z", "z-trans", "hilbert", "hilbert-trans") - for all layers, or a list of strings for different layers. Each encoder and decoder stage will use - order_type[i % len(order_type)] where i is the stage index. + order (Union[str, tuple]): The type(s) of point ordering. Can be a single string ("vdb", "z", "z-trans", "hilbert", "hilbert-trans") + or a tuple of strings (e.g., ("z", "z-trans")). Each block within a stage cycles through the order types. + shuffle_orders (bool): Whether to shuffle the order of order types at the beginning of each forward pass and after each pooling. 
""" super().__init__() self.num_classes = num_classes self.drop_path = drop_path self.proj_drop = proj_drop + self.qk_scale = qk_scale self.no_conv_in_cpe = no_conv_in_cpe self.sliding_window_attention = sliding_window_attention + self.shuffle_orders = shuffle_orders - # Handle order_type: convert to list for uniform processing - if isinstance(order_type, str): - self.order_type_list = [order_type] - else: - self.order_type_list = order_type - self.order_type = order_type # Keep original for backward compatibility + # Handle order: convert to tuple for uniform processing (matching reference implementation) + self.order_type = tuple([order_type]) if isinstance(order_type, str) else tuple(order_type) if not enable_batch_norm: self.norm_layer = torch.nn.LayerNorm @@ -713,10 +748,9 @@ def __init__( in_channels=enc_channels[i - 1], out_channels=enc_channels[i], norm_layer_module=self.norm_layer, - ) ) - # Select order_type for this encoder stage using modulo - stage_order_type = self.order_type_list[i % len(self.order_type_list)] + ) + # All encoder stages share the same order types; blocks within each stage cycle through them self.enc.append( PTV3_Encoder( enc_channels[i], @@ -725,9 +759,10 @@ def __init__( enc_drop_path[sum(enc_depths[:i]) : sum(enc_depths[: i + 1])], proj_drop, patch_size, + qk_scale, no_conv_in_cpe, sliding_window_attention, - stage_order_type, + self.order_type, self.shared_plan_cache, ) ) @@ -757,10 +792,7 @@ def __init__( norm_layer_module=self.norm_layer, ) ) - # Select order_type for this decoder stage using modulo - # Use reverse order for decoder (from last encoder stage backwards) - dec_stage_idx = self.num_stages - 1 - i - stage_order_type = self.order_type_list[dec_stage_idx % len(self.order_type_list)] + # All decoder stages share the same order types; blocks within each stage cycle through them self.dec.append( PTV3_Encoder( dec_channels[i], @@ -769,25 +801,50 @@ def __init__( dec_drop_path_, proj_drop, patch_size, + qk_scale, no_conv_in_cpe, sliding_window_attention, - stage_order_type, + self.order_type, self.shared_plan_cache, ) ) + def _shuffle_order(self): + """ + Randomly shuffle the order tuple to create variation across forward passes. + Returns a new shuffled tuple of order types. 
+ """ + if self.shuffle_orders: + indices = torch.randperm(len(self.order_type)) + return tuple(self.order_type[i] for i in indices) + else: + return self.order_type + def forward(self, grid, feats): nvtx.range_push("PTV3_Forward") + # Shuffle order at the beginning of forward pass (matching reference implementation) + shuffled_order = self._shuffle_order() + + # Store shuffled order in grid metadata so all blocks can access it + grid._shuffled_order = shuffled_order + grid, feats = self.embedding(grid, feats) layer_id = 0 - stack = [] + stack = [] # Stack stores (grid, feats, shuffled_order) tuples for i in range(self.num_stages): if i > 0: nvtx.range_push(f"PTV3_Pooling_{layer_id}") - stack.append((grid, feats)) + # Push grid, feats, AND the current shuffled_order to stack + # The decoder will reuse this exact shuffled order for the corresponding stage + stack.append((grid, feats, shuffled_order)) grid, feats = self.enc[layer_id](grid, feats) + + # Shuffle order after pooling for the next (downsampled) stage + shuffled_order = self._shuffle_order() + grid._shuffled_order = shuffled_order + nvtx.range_pop() layer_id += 1 nvtx.range_push(f"PTV3_Encoder_{layer_id}") @@ -799,12 +856,20 @@ def forward(self, grid, feats): layer_id = 0 for i in range(self.num_dec_stages): nvtx.range_push(f"PTV3_Unpooling_{layer_id}") - last_grid, last_feats = stack.pop() + # Pop grid, feats, AND the shuffled_order from the corresponding encoder stage + last_grid, last_feats, last_shuffled_order = stack.pop() + + # Restore the shuffled order from the encoder stage to the grids + # This ensures decoder blocks use the SAME order as the corresponding encoder blocks + last_grid._shuffled_order = last_shuffled_order + grid, feats = self.dec[layer_id](grid, feats, last_grid, last_feats) + # After unpooling, grid becomes last_grid with the restored shuffled order nvtx.range_pop() layer_id += 1 nvtx.range_push(f"PTV3_Decoder_{layer_id}") + # Decoder blocks use grid with the restored shuffled order from encoder grid, feats = self.dec[layer_id](grid, feats) nvtx.range_pop() layer_id += 1 From e2437bf3c9ca05ef837c772dd2f5505beb7f968f Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Fri, 7 Nov 2025 19:47:29 +0000 Subject: [PATCH 10/10] refactor. Signed-off-by: Hexu Zhao --- point_transformer_v3/model.py | 56 ++++++++++++++++------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/point_transformer_v3/model.py b/point_transformer_v3/model.py index bf8a2c5..7295bea 100644 --- a/point_transformer_v3/model.py +++ b/point_transformer_v3/model.py @@ -88,10 +88,7 @@ def _get_plan(self, grid, kernel_size, stride): cache_key = (grid.address, kernel_size, stride) if cache_key not in self.shared_plan_cache: self.shared_plan_cache[cache_key] = fvdb.ConvolutionPlan.from_grid_batch( - kernel_size=kernel_size, - stride=stride, - source_grid=grid, - target_grid=grid + kernel_size=kernel_size, stride=stride, source_grid=grid, target_grid=grid ) return self.shared_plan_cache[cache_key] @@ -110,7 +107,7 @@ def forward(self, grid, feats): # First 3x3 convolution plan1 = self._get_plan(grid, kernel_size=3, stride=1) feats = self.embed_conv3x3_1(feats, plan1) - + # Second 3x3 convolution (same grid since stride=1, in-place) plan2 = self._get_plan(grid, kernel_size=3, stride=1) feats = self.embed_conv3x3_2(feats, plan2) @@ -222,7 +219,7 @@ def __init__(self, hidden_size: int, proj_drop: float = 0.0): def forward(self, grid, feats): nvtx.range_push("PTV3_MLP") - feats_j = feats.jdata # TODO: deprecate the .jdata usage. 
+ feats_j = feats.jdata # TODO: deprecate the .jdata usage. feats_j = self.fc1(feats_j) feats_j = self.act(feats_j) @@ -364,10 +361,10 @@ def forward(self, grid, feats): # Get the shuffled order from grid metadata if available, otherwise use default order_types # This allows for order shuffling per forward pass (matching reference implementation) active_order_types = grid._shuffled_order - + # Get the order type for this block using the order index order_type = active_order_types[self.order_index % len(active_order_types)] - + if order_type != "vdb": perm = self._permute(grid, order_type).jdata.squeeze(-1) # [num_voxels] # Use torch.gather for permutation: expand perm to match feats_j dimensions @@ -394,10 +391,7 @@ def forward(self, grid, feats): qkv_b = qkv[start:end].view(1, Li, 3, H, D) window_size = (self.patch_size // 2, self.patch_size // 2) out_b = flash_attn.flash_attn_qkvpacked_func( - qkv_b.half(), - dropout_p=0.0, - softmax_scale=self.scale, - window_size=window_size + qkv_b.half(), dropout_p=0.0, softmax_scale=self.scale, window_size=window_size ).reshape( Li, self.hidden_size ) # dtype: float16 @@ -438,11 +432,11 @@ def forward(self, grid, feats): cu_seqlens[1:] = torch.as_tensor(lengths, device=qkv.device, dtype=torch.int32).cumsum(dim=0) feats_out_j = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv.half(), - cu_seqlens, - max_seqlen=self.patch_size, - dropout_p=0.0, # TODO: implement attention dropout in the future. By default, it is 0. - softmax_scale=self.scale + qkv.half(), + cu_seqlens, + max_seqlen=self.patch_size, + dropout_p=0.0, # TODO: implement attention dropout in the future. By default, it is 0. + softmax_scale=self.scale, ).reshape( num_voxels, self.hidden_size ) # dtype: float16 @@ -481,7 +475,7 @@ def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False, shared_plan_c self.cpe = torch.nn.ModuleList( [ ( - fvdb.nn.SparseConv3d(hidden_size, hidden_size, kernel_size=3, stride=1) # by default, bias is True. + fvdb.nn.SparseConv3d(hidden_size, hidden_size, kernel_size=3, stride=1) # by default, bias is True. if not no_conv_in_cpe else torch.nn.Identity() ), @@ -495,10 +489,7 @@ def _get_plan(self, grid, kernel_size, stride): cache_key = (grid.address, kernel_size, stride) if cache_key not in self.shared_plan_cache: self.shared_plan_cache[cache_key] = fvdb.ConvolutionPlan.from_grid_batch( - kernel_size=kernel_size, - stride=stride, - source_grid=grid, - target_grid=grid + kernel_size=kernel_size, stride=stride, source_grid=grid, target_grid=grid ) return self.shared_plan_cache[cache_key] @@ -573,7 +564,7 @@ def __init__( def forward(self, grid, feats): nvtx.range_push("PTV3_Block") grid, feats_out = self.cpe(grid, feats) - feats = grid.jagged_like(feats.jdata + feats_out.jdata) # Is this a potential issue? + feats = grid.jagged_like(feats.jdata + feats_out.jdata) # Is this a potential issue? 
short_cut = feats.jdata feats = grid.jagged_like(self.norm1(feats.jdata)) @@ -732,8 +723,11 @@ def __init__( self.shared_plan_cache = {} self.embedding = PTV3_Embedding( - input_dim, enc_channels[0], norm_layer_module=self.norm_layer, - embedding_mode=embedding_mode, shared_plan_cache=self.shared_plan_cache + input_dim, + enc_channels[0], + norm_layer_module=self.norm_layer, + embedding_mode=embedding_mode, + shared_plan_cache=self.shared_plan_cache, ) self.num_stages = len(enc_depths) @@ -748,8 +742,8 @@ def __init__( in_channels=enc_channels[i - 1], out_channels=enc_channels[i], norm_layer_module=self.norm_layer, + ) ) - ) # All encoder stages share the same order types; blocks within each stage cycle through them self.enc.append( PTV3_Encoder( @@ -825,7 +819,7 @@ def forward(self, grid, feats): # Shuffle order at the beginning of forward pass (matching reference implementation) shuffled_order = self._shuffle_order() - + # Store shuffled order in grid metadata so all blocks can access it grid._shuffled_order = shuffled_order @@ -840,11 +834,11 @@ def forward(self, grid, feats): # The decoder will reuse this exact shuffled order for the corresponding stage stack.append((grid, feats, shuffled_order)) grid, feats = self.enc[layer_id](grid, feats) - + # Shuffle order after pooling for the next (downsampled) stage shuffled_order = self._shuffle_order() grid._shuffled_order = shuffled_order - + nvtx.range_pop() layer_id += 1 nvtx.range_push(f"PTV3_Encoder_{layer_id}") @@ -858,11 +852,11 @@ def forward(self, grid, feats): nvtx.range_push(f"PTV3_Unpooling_{layer_id}") # Pop grid, feats, AND the shuffled_order from the corresponding encoder stage last_grid, last_feats, last_shuffled_order = stack.pop() - + # Restore the shuffled order from the encoder stage to the grids # This ensures decoder blocks use the SAME order as the corresponding encoder blocks last_grid._shuffled_order = last_shuffled_order - + grid, feats = self.dec[layer_id](grid, feats, last_grid, last_feats) # After unpooling, grid becomes last_grid with the restored shuffled order nvtx.range_pop()
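
For reference, a minimal sketch of the per-grid serialization logic that patches 07 and 09 build on: voxels are argsorted by their space-filling-curve codes within each grid (with global offsets preserved), features are gathered in that order for windowed attention, and the inverse permutation restores the original voxel order afterwards. The helper names `per_grid_argsort`, `invert_permutation`, `curve_codes`, and `voxels_per_grid` are illustrative assumptions only; the model itself reads the codes from `grid.morton()`/`grid.hilbert()` and works on `JaggedTensor`s via `_compute_permutation` and `torch.gather`.

```python
from typing import List

import torch


def per_grid_argsort(curve_codes: torch.Tensor, voxels_per_grid: List[int]) -> torch.Tensor:
    """Indices that sort each grid's voxels by its curve codes, kept global via the grid offset."""
    perm = torch.empty_like(curve_codes, dtype=torch.long)
    offset = 0
    for n in voxels_per_grid:
        if n == 0:
            continue
        _, idx = torch.sort(curve_codes[offset : offset + n])
        perm[offset : offset + n] = idx + offset  # keep indices global, as in _compute_permutation
        offset += n
    return perm


def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
    """Inverse permutation, matching perm_reverse[perm] = arange(num_voxels) in PTV3_Attention."""
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel(), device=perm.device)
    return inv


# Toy check: a batch of two grids with 3 and 2 voxels.
codes = torch.tensor([5, 1, 3, 9, 2])
perm = per_grid_argsort(codes, [3, 2])          # -> tensor([1, 2, 0, 4, 3])
feats = torch.arange(10.0).reshape(5, 2)
reordered = feats[perm]                         # gather rows in curve order for attention
restored = reordered[invert_permutation(perm)]  # undo the reordering afterwards
assert torch.equal(restored, feats)
```

The round-trip assert at the end mirrors the `perm`/`perm_reverse` pair used in `PTV3_Attention.forward` before and after the flash-attention call.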
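
Similarly, a small sketch of the shared-plan caching pattern introduced in patch 08, with the fvdb specifics abstracted away: plans are memoized on a `(grid identity, kernel_size, stride)` key so every layer touching the same grid reuses one plan, and the cache is cleared at the end of each forward pass so plans for stale grids do not accumulate. `SharedPlanCache`, `make_plan`, and `grid_key` are illustrative stand-ins, not fvdb API; the model itself calls `fvdb.ConvolutionPlan.from_grid_batch` and keys the dict on `grid.address`.

```python
from typing import Callable, Dict, Hashable, Tuple


class SharedPlanCache:
    """Memoize expensive per-grid plans, shared by all layers, cleared after each forward."""

    def __init__(self, make_plan: Callable[[Hashable, int, int], object]):
        self._make_plan = make_plan
        self._cache: Dict[Tuple[Hashable, int, int], object] = {}

    def get(self, grid_key: Hashable, kernel_size: int, stride: int):
        key = (grid_key, kernel_size, stride)
        if key not in self._cache:
            # Built once per (grid, kernel, stride); later layers on the same grid reuse it.
            self._cache[key] = self._make_plan(grid_key, kernel_size, stride)
        return self._cache[key]

    def clear(self) -> None:
        # Called at the end of forward() so plans for old grids do not pile up (OOM guard).
        self._cache.clear()


# Toy usage: count how many plans are actually built.
built = []
cache = SharedPlanCache(lambda g, k, s: built.append((g, k, s)) or (g, k, s))
cache.get("grid0", 3, 1)
cache.get("grid0", 3, 1)   # hit: reuses the first plan
cache.get("grid1", 3, 1)   # miss: different grid, new plan
assert len(built) == 2
cache.clear()              # analogous to self.shared_plan_cache.clear() after forward
```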