Skip to content

Commit 5fe93a9

Browse files
author
muzhancun
committed
[Model] Fix STEVE-1
1 parent 979194f commit 5fe93a9

File tree

9 files changed

+24
-24
lines changed

9 files changed

+24
-24
lines changed

docs/source/inference/baseline-groot.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ The example code is provided in ``minestudio/tutorials/inference/evaluate_groot/
2525
2626
if __name__ == '__main__':
2727
ray.init()
28-
task_configs = prepare_task_configs("simple")
28+
task_configs = prepare_task_configs("simple", path="CraftJarvis/MineStudio_task_group.simple")
2929
config_file = task_configs["collect_wood"]
3030
# you can try: survive_plant, collect_wood, build_pillar, ... ; make sure the config file contains `reference_video` field
3131
print(config_file)

docs/source/inference/baseline-steve1.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ The example code is provided in ``minestudio/tutorials/inference/evaluate_steve/
4747
4848
if __name__ == '__main__':
4949
ray.init()
50-
task_configs = prepare_task_configs("simple")
50+
task_configs = prepare_task_configs("simple", path="CraftJarvis/MineStudio_task_group.simple")
5151
config_file = task_configs["collect_wood"]
5252
# you can try: survive_plant, collect_wood, build_pillar, ... ; make sure the config file contains `reference_video` field
5353
print(config_file)

docs/source/inference/baseline-vpt.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Finally, we create an ``EpisodePipeline`` object, passing ``MineGenerator`` as t
6161
),
6262
episode_filter=InfoBaseFilter(
6363
key="mine_block",
64-
val="diamond_ore",
64+
regex=".*diamond_ore.*",
6565
num=1,
6666
),
6767
)

minestudio/benchmark/utility/read_conf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
'''
22
Date: 2024-12-06 16:42:49
3-
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
4-
LastEditTime: 2025-01-07 11:08:39
3+
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
4+
LastEditTime: 2025-06-06 13:59:32
55
FilePath: /MineStudio/minestudio/benchmark/utility/read_conf.py
66
'''
77
import os
@@ -38,7 +38,7 @@ def prepare_task_configs(group_name: str, path: Optional[str] = None, refresh: b
3838
print(f"Refreshing the cache: removing existing task configs from: {local_dir}")
3939
shutil.rmtree(local_dir)
4040
if not os.path.exists(local_dir):
41-
if os.path.isdir(path):
41+
if path is not None and os.path.isdir(path):
4242
shutil.copytree(path, local_dir)
4343
else:
4444
print(f"Downloading task configs from 🤗: {path}")

minestudio/models/steve_one/body.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,9 @@ def prepare_condition(self, instruction: Dict[str, Any], deterministic: bool = F
530530
assert 'text' in instruction, "instruction must have either text or video."
531531

532532
texts = instruction['text']
533-
if isinstance(texts, str):
533+
if isinstance(texts, list):
534+
texts = texts[0]
535+
elif isinstance(texts, str):
534536
texts = [texts]
535537
assert isinstance(texts, list) and isinstance(texts[0], str), "text must be a string or a list of strings."
536538

@@ -671,16 +673,14 @@ def load_steve_one_policy(ckpt_path: str) -> SteveOnePolicy:
671673
if __name__ == '__main__':
672674
model = SteveOnePolicy.from_pretrained("CraftJarvis/MineStudio_STEVE-1.official").to("cuda")
673675
model.eval()
674-
condition = model.prepare_condition(
675-
{
676-
'cond_scale': 4.0,
677-
'video': np.random.randint(0, 255, (2, 16, 224, 224, 3)).astype(np.uint8)
678-
}
679-
)
680-
output, memory = model(condition,
676+
output, memory = model.get_action(
681677
input={
682678
'image': torch.zeros(2, 8, 128, 128, 3).to("cuda"),
683-
'condition': condition
679+
'condition': {
680+
'cond_scale': 4.0,
681+
'text': 'mine dirt',
682+
}
684683
},
685-
state_in=model.initial_state(condition, 2)
686-
)
684+
state_in=None
685+
)
686+
print(output.keys())

minestudio/tutorials/inference/evaluate_groot/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
if __name__ == '__main__':
2121
ray.init()
22-
task_configs = prepare_task_configs("simple")
22+
task_configs = prepare_task_configs("simple", path="CraftJarvis/MineStudio_task_group.simple")
2323
config_file = task_configs["collect_wood"]
2424
# you can try: survive_plant, collect_wood, build_pillar, ... ; make sure the config file contains `reference_video` field
2525
print(config_file)

minestudio/tutorials/inference/evaluate_steve/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def after_step(self, sim, obs, reward, terminated, truncated, info):
3535

3636
if __name__ == '__main__':
3737
ray.init()
38-
task_configs = prepare_task_configs("simple")
38+
task_configs = prepare_task_configs("simple", path="CraftJarvis/MineStudio_task_group.simple")
3939
config_file = task_configs["collect_wood"]
4040
# you can try: survive_plant, collect_wood, build_pillar, ... ; make sure the config file contains `reference_video` field
4141
print(config_file)

minestudio/tutorials/inference/evaluate_vpts/mine_diamond.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
'''
22
Date: 2024-11-25 08:11:33
3-
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
4-
LastEditTime: 2025-01-04 11:39:20
3+
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
4+
LastEditTime: 2025-06-06 14:02:15
55
FilePath: /MineStudio/minestudio/tutorials/inference/evaluate_vpts/mine_diamond.py
66
'''
77
import ray
@@ -41,7 +41,7 @@
4141
),
4242
episode_filter=InfoBaseFilter(
4343
key="mine_block",
44-
val="diamond_ore",
44+
regex=".*diamond_ore.*",
4545
num=1,
4646
),
4747
)

minestudio/tutorials/inference/evaluate_vpts/shoot_animals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
'''
22
Date: 2024-12-13 14:31:12
3-
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
4-
LastEditTime: 2025-01-04 13:54:09
3+
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
4+
LastEditTime: 2025-06-06 14:03:50
55
FilePath: /MineStudio/minestudio/tutorials/inference/evaluate_vpts/shoot_animals.py
66
'''
77
import ray

0 commit comments

Comments
 (0)