Skip to content

Commit 50098f8

Browse files
committed
supports external methods
1 parent 88fbb04 commit 50098f8

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

modules/trainer_tab.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import webbrowser
44
import argparse
55
import gradio as gr
6+
from traitlets import Instance
67

78
from nerfstudio.configs import dataparser_configs as dc, method_configs as mc
9+
from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig
810
from utils.trainer import WebUITrainer
911
from utils.utils import run_cmd, get_folder_path, browse_folder, submit, generate_args
1012
from nerfstudio.viewer_legacy.server import viewer_utils
@@ -40,6 +42,14 @@ def __init__(self, args: argparse.Namespace):
4042
self.dist_url = args.dist_url
4143
self.user_websocket_port = args.websocket_port
4244

45+
self.use_external_methods = args.use_external_methods
46+
if self.use_external_methods:
47+
self.method_descriptions = mc.all_descriptions
48+
self.dataparsers = dc.all_dataparsers
49+
else:
50+
self.method_descriptions = mc.descriptions
51+
self.dataparsers = dc.dataparsers
52+
4353
self.websocket_port = None
4454

4555
def setup_ui(self):
@@ -105,15 +115,15 @@ def setup_ui(self):
105115
with gr.Row():
106116
with gr.Column():
107117
method = gr.Radio(
108-
choices=list(mc.descriptions.keys()), label="Method"
118+
choices=list(self.method_descriptions.keys()), label="Method"
109119
)
110120
description = gr.Textbox(label="Description", visible=True)
111121
method.change(
112122
self.get_model_description, inputs=method, outputs=description
113123
)
114124
with gr.Column():
115125
dataparser = gr.Radio(
116-
choices=["default"] + list(dc.dataparsers.keys()),
126+
choices=["default"] + list(self.dataparsers.keys()),
117127
label="Data Parser",
118128
value="default",
119129
)
@@ -133,10 +143,16 @@ def setup_ui(self):
133143
)
134144

135145
with gr.Accordion("Model Config", open=False):
136-
for key, value in mc.descriptions.items():
146+
for key, value in self.method_descriptions.items():
137147
with gr.Group(visible=False) as group:
138-
if key in mc.method_configs:
139-
model_config = mc.method_configs[key].pipeline.model # type: ignore
148+
if key in mc.all_methods:
149+
if (
150+
type(mc.all_methods[key])
151+
is ExternalMethodDummyTrainerConfig
152+
):
153+
continue
154+
155+
model_config = mc.all_methods[key].pipeline.model # type: ignore
140156
generated_args, labels = generate_args(
141157
model_config, visible=True
142158
)
@@ -155,7 +171,7 @@ def setup_ui(self):
155171
)
156172

157173
with gr.Accordion("Data Parser Config", open=False):
158-
for key, parser_config in dc.dataparsers.items():
174+
for key, parser_config in self.dataparsers.items():
159175
with gr.Group(visible=False) as group:
160176
generated_args, labels = generate_args(
161177
parser_config, visible=True
@@ -315,7 +331,7 @@ def run_train(
315331
config.viewer.websocket_port = self.websocket_port
316332

317333
if data_parser != "default":
318-
config.pipeline.datamanager.dataparser = dc.all_dataparsers[data_parser]
334+
config.pipeline.datamanager.dataparser = self.dataparsers[data_parser]
319335
for key, value in self.dataparser_args.items():
320336
setattr(config.pipeline.datamanager.dataparser, key, value)
321337

@@ -406,7 +422,7 @@ def get_data_parser_args(self, dataparser, *args):
406422
self.dataparser_args = temp_args
407423

408424
def get_model_description(self, method):
409-
return mc.descriptions[method]
425+
return self.method_descriptions[method]
410426

411427
def update_dataparser_args_visibility(self, dataparser):
412428
# print(group_keys)

webui.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ def launch(self, **kwargs):
138138
help="Disable the Exporter tab",
139139
)
140140

141+
parser.add_argument(
142+
"--use_external_methods",
143+
action="store_true",
144+
default=False,
145+
help="Use external methods in the Trainer tab",
146+
)
147+
141148
parsed_args: argparse.Namespace = parser.parse_args()
142149

143150
app = WebUI(parsed_args)

0 commit comments

Comments
 (0)