Skip to content

Commit 7f36d55

Browse files
author
Eric Cornwell
committed
Saving checkpoint to s3, enabled NeRF output, increased training timeout, pre-downloading u2net and alexnet models, improved vanilla bg remover performance using built-in folder option, refined equirectangular_to_perspective.py to output angled connective views
1 parent 445c584 commit 7f36d55

File tree

6 files changed

+1070
-272
lines changed

6 files changed

+1070
-272
lines changed

source/Gradio/generate_splat_gradio.py

Lines changed: 212 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
import boto3.s3.transfer
3131

3232
print(f"Gradio Version: {gr.__version__}")
33-
UPLOAD_DIR = os.path.join(os.path.dirname(__file__), "uploads")
34-
os.makedirs(UPLOAD_DIR, exist_ok=True)
3533

3634
class SharedState:
3735
def __init__(self):
@@ -182,8 +180,10 @@ def refresh_s3_contents():
182180
if item['Key'].endswith('/'):
183181
continue
184182

185-
# Check if file is .ply or .spz
186-
if not (item['Key'].lower().endswith('.ply') or item['Key'].lower().endswith('.spz')):
183+
# Check if file is .ply, .spz, or .glb
184+
if not (item['Key'].lower().endswith('.ply') or
185+
item['Key'].lower().endswith('.spz') or
186+
item['Key'].lower().endswith('.glb')):
187187
continue
188188

189189
# Get the job ID from the path
@@ -715,7 +715,8 @@ def create_advanced_settings_tab():
715715
"splatfacto-mcmc",
716716
"splatfacto-w-light",
717717
"3dgut",
718-
"3dgrt"
718+
"3dgrt",
719+
"nerfacto"
719720
],
720721
value="splatfacto"
721722
)
@@ -807,11 +808,11 @@ def on_select(evt: gr.SelectData, data):
807808
gr.update(interactive=False)
808809
]
809810

810-
def handle_view(selected_row):
811-
"""Handle view button click"""
811+
def handle_view_with_progress(selected_row):
812+
"""Handle view button click with progress bar"""
812813
try:
813814
if not selected_row:
814-
return gr.update(value=None), "No file selected"
815+
return gr.update(value=None), "No file selected", ""
815816

816817
bucket_name = shared_state.s3_bucket
817818
output_prefix = shared_state.s3_output or "workflow-output"
@@ -820,8 +821,17 @@ def handle_view(selected_row):
820821
filename = selected_row[1]
821822
file_key = f"{output_prefix}/{job_id}/{filename}"
822823

824+
# Check if this is the currently loaded model
825+
current_url = getattr(shared_state, 'current_model_url', None)
826+
current_key = getattr(shared_state, 'current_model_key', None)
827+
828+
if current_key == file_key:
829+
# Model is already loaded, don't show progress bar
830+
return gr.update(value=current_url), f"Already loaded: {filename}", ""
831+
823832
# Get file size information
824833
file_size_mb = None
834+
size_info = ""
825835
try:
826836
s3_client = boto3.client('s3')
827837
response = s3_client.head_object(Bucket=bucket_name, Key=file_key)
@@ -830,25 +840,114 @@ def handle_view(selected_row):
830840
size_info = f" ({file_size_mb:.1f} MB)"
831841
except Exception as e:
832842
print(f"Error getting file size: {str(e)}")
833-
size_info = ""
834-
835-
# Note: Removed artificial delay as it's not necessary for functionality
836843

837844
# Generate a presigned URL for the file
838845
presigned_url = generate_presigned_url(bucket_name, file_key)
839846

840847
if not presigned_url:
841-
return gr.update(value=None), "Error generating URL"
848+
return gr.update(value=None), "Error generating URL", ""
849+
850+
# Store current model info
851+
shared_state.current_model_url = presigned_url
852+
shared_state.current_model_key = file_key
842853

843-
# Return the URL for the Model3D component
844-
return gr.update(value=presigned_url), f"Loading {filename}{size_info}..."
854+
# Estimate loading time based on file size
855+
estimated_time = file_size_mb * 0.5 if file_size_mb else 25
856+
857+
# Create a unique ID for this model
858+
model_id = f"{job_id}_{filename.replace('.', '_')}"
859+
860+
# Track loaded models in shared state
861+
if not hasattr(shared_state, 'loaded_models'):
862+
shared_state.loaded_models = set()
863+
864+
# Only show progress bar for models not yet loaded
865+
if file_key not in shared_state.loaded_models:
866+
shared_state.loaded_models.add(file_key)
867+
868+
# Create progress bar HTML
869+
progress_html = f"""
870+
<div style="margin: 10px 0;">
871+
<div style="background: #f0f0f0; border-radius: 10px; overflow: hidden; height: 20px;">
872+
<div style="background: linear-gradient(90deg, #4CAF50, #45a049); height: 100%; width: 0%; animation: loading {estimated_time}s ease-out forwards;"></div>
873+
</div>
874+
<div style="text-align: center; margin-top: 5px; font-size: 14px;">Loading {filename}{size_info}... (~{estimated_time:.0f}s estimated)</div>
875+
</div>
876+
<style>
877+
@keyframes loading {{
878+
0% {{ width: 0%; }}
879+
30% {{ width: 40%; }}
880+
60% {{ width: 70%; }}
881+
90% {{ width: 90%; }}
882+
100% {{ width: 100%; }}
883+
}}
884+
</style>
885+
"""
886+
else:
887+
# Empty progress HTML if already loaded
888+
progress_html = ""
889+
890+
# Estimate loading time based on file size
891+
# Model based on actual loading times:
892+
# 12MB=6sec, 31MB=9sec, 180MB=75sec, 236MB=105sec, 448MB=220sec
893+
if file_size_mb is None:
894+
estimated_time = 10
895+
else:
896+
# Quadratic model: time = 0.001x² + 0.3x + 3
897+
estimated_time = 0.001 * (file_size_mb ** 2) + 0.3 * file_size_mb + 3
898+
899+
# Only show progress bar when the View button is clicked
900+
# Check if we're navigating between tabs by looking at the referrer
901+
progress_html = f"""
902+
<div style="margin: 10px 0;">
903+
<div style="background: #f0f0f0; border-radius: 10px; overflow: hidden; height: 20px;">
904+
<div style="background: linear-gradient(90deg, #4CAF50, #45a049); height: 100%; width: 0%; animation: loading {estimated_time}s ease-out forwards;"></div>
905+
</div>
906+
<div style="text-align: center; margin-top: 5px; font-size: 14px;">Loading {filename}{size_info}... (~{estimated_time:.0f}s estimated)</div>
907+
</div>
908+
<style>
909+
@keyframes loading {{
910+
0% {{ width: 0%; }}
911+
30% {{ width: 40%; }}
912+
60% {{ width: 70%; }}
913+
90% {{ width: 90%; }}
914+
100% {{ width: 100%; }}
915+
}}
916+
</style>
917+
<script>
918+
(function() {{
919+
// Check if this is a tab navigation by looking at document.referrer
920+
const isTabNavigation = document.referrer.includes(window.location.origin);
921+
922+
// If this is tab navigation, hide the progress bar
923+
if (isTabNavigation) {{
924+
// Find all progress bars and hide them
925+
const progressBars = document.querySelectorAll('div[style*="margin: 10px 0;"]');
926+
progressBars.forEach(bar => {{
927+
bar.style.display = 'none';
928+
}});
929+
}}
930+
}})();
931+
</script>
932+
"""
933+
934+
# Create a unique ID for this model
935+
model_id = f"{job_id}_{filename.replace('.', '_')}"
936+
937+
# Return all three required outputs
938+
return gr.update(value=presigned_url), f"Loading {filename}...", progress_html
845939

846940
except Exception as e:
847941
error_msg = f"Error viewing file: {str(e)}"
848-
print(f"[DEBUG] Error in handle_view: {error_msg}")
942+
print(f"[DEBUG] Error in handle_view_with_progress: {error_msg}")
849943
import traceback
850944
traceback.print_exc()
851-
return gr.update(value=None), error_msg
945+
return gr.update(value=None), error_msg, ""
946+
947+
def handle_view(selected_row):
948+
"""Handle view button click"""
949+
result = handle_view_with_progress(selected_row)
950+
return result[0], result[1]
852951

853952
def add_to_favorites(selected_data):
854953
"""Add currently selected item to favorites"""
@@ -1194,6 +1293,12 @@ def update_favorites_ui():
11941293
with favorites_container:
11951294
update_favorites_ui()
11961295

1296+
# Progress bar HTML component - always visible
1297+
progress_bar = gr.HTML(
1298+
value="",
1299+
visible=True
1300+
)
1301+
11971302
# Now render the viewer
11981303
viewer.render()
11991304
viewer_status.render()
@@ -1365,11 +1470,94 @@ def refresh_button_handler():
13651470
outputs=[download_iframe]
13661471
)
13671472

1473+
# Add a JavaScript function to track loaded models
1474+
tracking_js = gr.HTML("""
1475+
<script>
1476+
// Make sure we have the global tracking object
1477+
if (typeof window.loadedModels === 'undefined') {
1478+
window.loadedModels = {};
1479+
}
1480+
</script>
1481+
""")
1482+
1483+
# Simplest approach - directly implement the progress bar
1484+
def handle_view_with_progress(selected_row):
1485+
try:
1486+
if not selected_row:
1487+
return gr.update(value=None), "No file selected", ""
1488+
1489+
bucket_name = shared_state.s3_bucket
1490+
output_prefix = shared_state.s3_output or "workflow-output"
1491+
1492+
job_id = selected_row[0]
1493+
filename = selected_row[1]
1494+
file_key = f"{output_prefix}/{job_id}/{filename}"
1495+
1496+
# Check if this is the currently loaded model
1497+
current_url = getattr(shared_state, 'current_model_url', None)
1498+
current_key = getattr(shared_state, 'current_model_key', None)
1499+
1500+
if current_key == file_key:
1501+
# Model is already loaded, don't show progress bar
1502+
return gr.update(value=current_url), f"Already loaded: {filename}", ""
1503+
1504+
# Get file size information
1505+
file_size_mb = None
1506+
size_info = ""
1507+
try:
1508+
s3_client = boto3.client('s3')
1509+
response = s3_client.head_object(Bucket=bucket_name, Key=file_key)
1510+
file_size = response['ContentLength']
1511+
file_size_mb = file_size / (1024 * 1024)
1512+
size_info = f" ({file_size_mb:.1f} MB)"
1513+
except Exception as e:
1514+
print(f"Error getting file size: {str(e)}")
1515+
1516+
# Generate a presigned URL for the file
1517+
presigned_url = generate_presigned_url(bucket_name, file_key)
1518+
1519+
if not presigned_url:
1520+
return gr.update(value=None), "Error generating URL", ""
1521+
1522+
# Store current model info
1523+
shared_state.current_model_url = presigned_url
1524+
shared_state.current_model_key = file_key
1525+
1526+
# Estimate loading time based on file size
1527+
estimated_time = file_size_mb * 0.5 if file_size_mb else 25
1528+
1529+
# Create progress bar HTML
1530+
progress_html = f"""
1531+
<div style="margin: 10px 0;">
1532+
<div style="background: #f0f0f0; border-radius: 10px; overflow: hidden; height: 20px;">
1533+
<div style="background: linear-gradient(90deg, #4CAF50, #45a049); height: 100%; width: 0%; animation: loading {estimated_time}s ease-out forwards;"></div>
1534+
</div>
1535+
<div style="text-align: center; margin-top: 5px; font-size: 14px;">Loading {filename}{size_info}... (~{estimated_time:.0f}s estimated)</div>
1536+
</div>
1537+
<style>
1538+
@keyframes loading {{
1539+
0% {{ width: 0%; }}
1540+
30% {{ width: 40%; }}
1541+
60% {{ width: 70%; }}
1542+
90% {{ width: 90%; }}
1543+
100% {{ width: 100%; }}
1544+
}}
1545+
</style>
1546+
"""
1547+
1548+
return gr.update(value=presigned_url), f"Loading {filename}...", progress_html
1549+
1550+
except Exception as e:
1551+
error_msg = f"Error viewing file: {str(e)}"
1552+
print(f"[DEBUG] Error in handle_view_with_progress: {error_msg}")
1553+
import traceback
1554+
traceback.print_exc()
1555+
return gr.update(value=None), error_msg, ""
1556+
13681557
view_btn.click(
1369-
fn=handle_view,
1558+
fn=handle_view_with_progress,
13701559
inputs=[selected_data],
1371-
outputs=[viewer, viewer_status],
1372-
show_progress="full" # Show full progress animation
1560+
outputs=[viewer, viewer_status, progress_bar]
13731561
)
13741562

13751563
# Update this to use the new function that updates the UI
@@ -1566,6 +1754,11 @@ def load_favorites():
15661754
def create_interface():
15671755
# Create the main Gradio interface
15681756
with gr.Blocks(title="Open Source 3D Reconstruction Toolbox for Gaussian Splats on AWS", theme=gr.themes.Ocean(), css="""
1757+
/* Add global tracking script */
1758+
<script>
1759+
// Global variable to track loaded models
1760+
window.loadedModels = {};
1761+
</script>
15691762
#viewer-container {
15701763
width: 100%;
15711764
height: 600px;

source/container/Dockerfile

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ RUN pip install git+https://github.com/nerfstudio-project/gsplat.git
231231
RUN git clone https://github.com/KevinXu02/splatfacto-w
232232
RUN pip install git+https://github.com/KevinXu02/splatfacto-w
233233

234+
# Pre-download AlexNet model
235+
RUN mkdir -p /root/.cache/torch/hub/checkpoints && \
236+
wget -O /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth \
237+
https://download.pytorch.org/models/alexnet-owt-7be5be79.pth
238+
239+
RUN wget -O /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth \
240+
https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
241+
234242
# Install AWS CLI v2 latest
235243
RUN wget "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -O "awscliv2.zip" && \
236244
unzip awscliv2.zip && \
@@ -242,8 +250,25 @@ RUN pip install --no-cache-dir scipy 'numpy<2.0.0' Pillow>=10.3.0 timm einops om
242250

243251
# Install the Background Removal Classic Tool
244252
RUN git clone https://github.com/nadermx/backgroundremover.git
253+
# Comment out print statement "DEBUG: path to be checked: /root/.u2net/u2net.pth"
254+
RUN sed -i '53s/^ print/ #print/' ${CODE_PATH}/backgroundremover/backgroundremover/u2net/detect.py
245255
RUN pip install -r ${CODE_PATH}/backgroundremover/requirements.txt
246256

257+
# Download U2NET models
258+
RUN mkdir -p /root/.u2net && \
259+
cd /root/.u2net && \
260+
wget -O u2net.pth.part1 https://github.com/nadermx/backgroundremover/raw/main/models/u2aa && \
261+
wget -O u2net.pth.part2 https://github.com/nadermx/backgroundremover/raw/main/models/u2ab && \
262+
wget -O u2net.pth.part3 https://github.com/nadermx/backgroundremover/raw/main/models/u2ac && \
263+
wget -O u2net.pth.part4 https://github.com/nadermx/backgroundremover/raw/main/models/u2ad && \
264+
cat u2net.pth.part* > u2net.pth && \
265+
rm u2net.pth.part* && \
266+
wget -O u2netp.pth https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth
267+
268+
# Set environment variables for the models
269+
ENV U2NET_PATH=/root/.u2net/u2net.pth
270+
ENV U2NETP_PATH=/root/.u2net/u2netp.pth
271+
247272
# Install the SAM2 Segmentation Model and Code
248273
RUN git clone https://github.com/facebookresearch/sam2.git
249274
RUN mv ${CODE_PATH}/sam2 ${CODE_PATH}/sam
@@ -256,7 +281,9 @@ RUN sed -i 's/"torch>=2.5.1",/#"torch>=2.5.1",/' setup.py && \
256281
sed -i 's/"hydra-core>=1.3.2",/"hydra-core==1.3.2",/' setup.py
257282

258283
# Install SAM2 dependencies
259-
RUN pip install -e .
284+
#RUN pip install -e .
285+
RUN pip install --no-deps -e . && \
286+
pip install hydra-core==1.3.2 cog>=0.14.12 opencv-python matplotlib
260287
WORKDIR ${CODE_PATH}
261288

262289
# Build NVIDIA 3DGRUT

0 commit comments

Comments
 (0)