Skip to content

Commit 1b4258d

Browse files
quaglacopybara-github
authored andcommitted
Change set_default and default to classname in the bindings.
Having defaults in the names of the attribute or function names is confusing for the users since the defaults are not re-applied but are only used for writing to XML. The alternative would be to change mjs_setDefault to re-apply them, but this would overwrite any other attribute set by the user so far. PiperOrigin-RevId: 726431090 Change-Id: I41f4a23e1722e278b0ecc0a00b07db9b384fe994
1 parent 89253d9 commit 1b4258d

File tree

2 files changed

+88
-109
lines changed

2 files changed

+88
-109
lines changed

python/mujoco/specs.cc

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -595,16 +595,14 @@ PYBIND11_MODULE(_specs, m) {
595595
[](raw::MjsBody& self, raw::MjsFrame& frame) -> void {
596596
mjs_setFrame(self.element, &frame);
597597
});
598-
mjsBody.def("set_default",
599-
[](raw::MjsBody& self, raw::MjsDefault& default_) -> void {
600-
mjs_setDefault(self.element, &default_);
601-
});
602-
mjsBody.def(
603-
"default",
598+
mjsBody.def_property(
599+
"classname",
604600
[](raw::MjsBody& self) -> raw::MjsDefault* {
605601
return mjs_getDefault(self.element);
606602
},
607-
py::return_value_policy::reference_internal);
603+
[](raw::MjsBody& self, raw::MjsDefault& default_) -> void {
604+
mjs_setDefault(self.element, &default_);
605+
});
608606
mjsBody.def(
609607
"find_all",
610608
[](raw::MjsBody& self, mjtObj objtype) -> py::list {
@@ -837,63 +835,60 @@ PYBIND11_MODULE(_specs, m) {
837835
mjsGeom.def("set_frame", [](raw::MjsGeom& self, raw::MjsFrame& frame) {
838836
mjs_setFrame(self.element, &frame);
839837
});
840-
mjsGeom.def("set_default", [](raw::MjsGeom& self, raw::MjsDefault& def) {
841-
mjs_setDefault(self.element, &def);
842-
});
843838
mjsGeom.def_property_readonly(
844839
"parent",
845840
[](raw::MjsGeom& self) -> raw::MjsBody* {
846841
return mjs_getParent(self.element);
847842
},
848843
py::return_value_policy::reference_internal);
849-
mjsGeom.def(
850-
"default",
844+
mjsGeom.def_property(
845+
"classname",
851846
[](raw::MjsGeom& self) -> raw::MjsDefault* {
852847
return mjs_getDefault(self.element);
853848
},
854-
py::return_value_policy::reference_internal);
849+
[](raw::MjsGeom& self, raw::MjsDefault& default_) -> void {
850+
mjs_setDefault(self.element, &default_);
851+
});
855852

856853
// ============================= MJSJOINT ====================================
857854
mjsJoint.def("delete", [](raw::MjsJoint& self) { mjs_delete(self.element); });
858855
mjsJoint.def("set_frame", [](raw::MjsJoint& self, raw::MjsFrame& frame) {
859856
mjs_setFrame(self.element, &frame);
860857
});
861-
mjsJoint.def("set_default", [](raw::MjsJoint& self, raw::MjsDefault& def) {
862-
mjs_setDefault(self.element, &def);
863-
});
864858
mjsJoint.def_property_readonly(
865859
"parent",
866860
[](raw::MjsJoint& self) -> raw::MjsBody* {
867861
return mjs_getParent(self.element);
868862
},
869863
py::return_value_policy::reference_internal);
870-
mjsJoint.def(
871-
"default",
864+
mjsJoint.def_property(
865+
"classname",
872866
[](raw::MjsJoint& self) -> raw::MjsDefault* {
873867
return mjs_getDefault(self.element);
874868
},
875-
py::return_value_policy::reference_internal);
869+
[](raw::MjsJoint& self, raw::MjsDefault& default_) -> void {
870+
mjs_setDefault(self.element, &default_);
871+
});
876872

877873
// ============================= MJSSITE =====================================
878874
mjsSite.def("delete", [](raw::MjsSite& self) { mjs_delete(self.element); });
879875
mjsSite.def("set_frame", [](raw::MjsSite& self, raw::MjsFrame& frame) {
880876
mjs_setFrame(self.element, &frame);
881877
});
882-
mjsSite.def("set_default", [](raw::MjsSite& self, raw::MjsDefault& def) {
883-
mjs_setDefault(self.element, &def);
884-
});
885878
mjsSite.def_property_readonly(
886879
"parent",
887880
[](raw::MjsSite& self) -> raw::MjsBody* {
888881
return mjs_getParent(self.element);
889882
},
890883
py::return_value_policy::reference_internal);
891-
mjsSite.def(
892-
"default",
884+
mjsSite.def_property(
885+
"classname",
893886
[](raw::MjsSite& self) -> raw::MjsDefault* {
894887
return mjs_getDefault(self.element);
895888
},
896-
py::return_value_policy::reference_internal);
889+
[](raw::MjsSite& self, raw::MjsDefault& default_) -> void {
890+
mjs_setDefault(self.element, &default_);
891+
});
897892
mjsSite.def(
898893
"attach_body",
899894
[](raw::MjsSite& self, raw::MjsBody& body,
@@ -918,115 +913,102 @@ PYBIND11_MODULE(_specs, m) {
918913
mjsCamera.def("set_frame", [](raw::MjsCamera& self, raw::MjsFrame& frame) {
919914
mjs_setFrame(self.element, &frame);
920915
});
921-
mjsCamera.def("set_default", [](raw::MjsCamera& self, raw::MjsDefault& def) {
922-
mjs_setDefault(self.element, &def);
923-
});
924916
mjsCamera.def_property_readonly(
925917
"parent",
926918
[](raw::MjsCamera& self) -> raw::MjsBody* {
927919
return mjs_getParent(self.element);
928920
},
929921
py::return_value_policy::reference_internal);
930-
mjsCamera.def(
931-
"default",
922+
mjsCamera.def_property(
923+
"classname",
932924
[](raw::MjsCamera& self) -> raw::MjsDefault* {
933925
return mjs_getDefault(self.element);
934926
},
935-
py::return_value_policy::reference_internal);
927+
[](raw::MjsCamera& self, raw::MjsDefault& default_) -> void {
928+
mjs_setDefault(self.element, &default_);
929+
});
936930

937931
// ============================= MJSLIGHT ====================================
938932
mjsLight.def("delete", [](raw::MjsLight& self) { mjs_delete(self.element); });
939933
mjsLight.def("set_frame", [](raw::MjsLight& self, raw::MjsFrame& frame) {
940934
mjs_setFrame(self.element, &frame);
941935
});
942-
mjsLight.def("set_default", [](raw::MjsLight& self, raw::MjsDefault& def) {
943-
mjs_setDefault(self.element, &def);
944-
});
945936
mjsLight.def_property_readonly(
946937
"parent",
947938
[](raw::MjsLight& self) -> raw::MjsBody* {
948939
return mjs_getParent(self.element);
949940
},
950941
py::return_value_policy::reference_internal);
951-
mjsLight.def(
952-
"default",
942+
mjsLight.def_property(
943+
"classname",
953944
[](raw::MjsLight& self) -> raw::MjsDefault* {
954945
return mjs_getDefault(self.element);
955946
},
956-
py::return_value_policy::reference_internal);
947+
[](raw::MjsLight& self, raw::MjsDefault& default_) -> void {
948+
mjs_setDefault(self.element, &default_);
949+
});
957950

958951
// ============================= MJSMATERIAL =================================
959952
mjsMaterial.def("delete",
960953
[](raw::MjsMaterial& self) { mjs_delete(self.element); });
961-
mjsMaterial.def("set_default",
962-
[](raw::MjsMaterial& self, raw::MjsDefault& def) {
963-
mjs_setDefault(self.element, &def);
964-
});
965-
mjsMaterial.def(
966-
"default",
954+
mjsMaterial.def_property(
955+
"classname",
967956
[](raw::MjsMaterial& self) -> raw::MjsDefault* {
968957
return mjs_getDefault(self.element);
969958
},
970-
py::return_value_policy::reference_internal);
959+
[](raw::MjsMaterial& self, raw::MjsDefault& default_) -> void {
960+
mjs_setDefault(self.element, &default_);
961+
});
971962

972963
// ============================= MJSMESH =====================================
973964
mjsMesh.def("delete", [](raw::MjsMesh& self) { mjs_delete(self.element); });
974-
mjsMesh.def("set_default", [](raw::MjsMesh& self, raw::MjsDefault& def) {
975-
mjs_setDefault(self.element, &def);
976-
});
977-
mjsMesh.def(
978-
"default",
965+
mjsMesh.def_property(
966+
"classname",
979967
[](raw::MjsMesh& self) -> raw::MjsDefault* {
980968
return mjs_getDefault(self.element);
981969
},
982-
py::return_value_policy::reference_internal);
970+
[](raw::MjsMesh& self, raw::MjsDefault& default_) -> void {
971+
mjs_setDefault(self.element, &default_);
972+
});
983973

984974
// ============================= MJSPAIR =====================================
985975
mjsPair.def("delete", [](raw::MjsPair& self) { mjs_delete(self.element); });
986-
mjsPair.def("set_default", [](raw::MjsPair& self, raw::MjsDefault& def) {
987-
mjs_setDefault(self.element, &def);
988-
});
989-
mjsPair.def(
990-
"default",
976+
mjsPair.def_property(
977+
"classname",
991978
[](raw::MjsPair& self) -> raw::MjsDefault* {
992979
return mjs_getDefault(self.element);
993980
},
994-
py::return_value_policy::reference_internal);
981+
[](raw::MjsPair& self, raw::MjsDefault& default_) -> void {
982+
mjs_setDefault(self.element, &default_);
983+
});
995984

996985
// ============================= MJSEQUAL ====================================
997986
mjsEquality.def("delete",
998987
[](raw::MjsEquality& self) { mjs_delete(self.element); });
999-
mjsEquality.def("set_default",
1000-
[](raw::MjsEquality& self, raw::MjsDefault& def) {
1001-
mjs_setDefault(self.element, &def);
1002-
});
1003-
mjsEquality.def(
1004-
"default",
988+
mjsEquality.def_property(
989+
"classname",
1005990
[](raw::MjsEquality& self) -> raw::MjsDefault* {
1006991
return mjs_getDefault(self.element);
1007992
},
1008-
py::return_value_policy::reference_internal);
993+
[](raw::MjsEquality& self, raw::MjsDefault& default_) -> void {
994+
mjs_setDefault(self.element, &default_);
995+
});
1009996

1010997
// ============================= MJSACTUATOR =================================
1011998
mjsActuator.def("delete",
1012999
[](raw::MjsActuator& self) { mjs_delete(self.element); });
1013-
mjsActuator.def("set_default",
1014-
[](raw::MjsActuator& self, raw::MjsDefault& def) {
1015-
mjs_setDefault(self.element, &def);
1016-
});
1017-
mjsActuator.def(
1018-
"default",
1000+
mjsActuator.def_property(
1001+
"classname",
10191002
[](raw::MjsActuator& self) -> raw::MjsDefault* {
10201003
return mjs_getDefault(self.element);
10211004
},
1022-
py::return_value_policy::reference_internal);
1005+
[](raw::MjsActuator& self, raw::MjsDefault& default_) -> void {
1006+
mjs_setDefault(self.element, &default_);
1007+
});
10231008

10241009
// ============================= MJSTENDON ===================================
10251010
mjsTendon.def("delete",
10261011
[](raw::MjsTendon& self) { mjs_delete(self.element); });
1027-
mjsTendon.def("set_default", [](raw::MjsTendon& self, raw::MjsDefault& def) {
1028-
mjs_setDefault(self.element, &def);
1029-
});
10301012
mjsTendon.def(
10311013
"default",
10321014
[](raw::MjsTendon& self) -> raw::MjsDefault* {

python/mujoco/specs_test.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -528,22 +528,7 @@ def test_uncompiled_spec_cannot_be_written(self):
528528
spec.to_xml()
529529

530530
def test_modelname_default_class(self):
531-
spec = mujoco.MjSpec()
532-
spec.modelname = 'test'
533-
534-
main = spec.default()
535-
main.geom.size[0] = 2
536-
537-
def1 = spec.add_default('def1', main)
538-
def1.geom.size[0] = 3
539-
540-
spec.worldbody.add_geom(def1)
541-
spec.worldbody.add_geom(main)
542-
543-
spec.compile()
544-
self.assertEqual(
545-
spec.to_xml(),
546-
textwrap.dedent("""\
531+
XML = textwrap.dedent("""\
547532
<mujoco model="test">
548533
<compiler angle="radian"/>
549534
@@ -559,8 +544,8 @@ def test_modelname_default_class(self):
559544
<geom/>
560545
</worldbody>
561546
</mujoco>
562-
"""),
563-
)
547+
""")
548+
564549
spec = mujoco.MjSpec()
565550
spec.modelname = 'test'
566551

@@ -574,26 +559,38 @@ def test_modelname_default_class(self):
574559
spec.worldbody.add_geom(main)
575560

576561
spec.compile()
577-
self.assertEqual(
578-
spec.to_xml(),
579-
textwrap.dedent("""\
580-
<mujoco model="test">
581-
<compiler angle="radian"/>
562+
self.assertEqual(spec.to_xml(), XML)
563+
spec = mujoco.MjSpec()
564+
spec.modelname = 'test'
582565

583-
<default>
584-
<geom size="2 0 0"/>
585-
<default class="def1">
586-
<geom size="3 0 0"/>
587-
</default>
588-
</default>
566+
main = spec.default()
567+
main.geom.size[0] = 2
568+
def1 = spec.add_default('def1', main)
569+
def1.geom.size[0] = 3
589570

590-
<worldbody>
591-
<geom class="def1"/>
592-
<geom/>
593-
</worldbody>
594-
</mujoco>
595-
"""),
596-
)
571+
geom1 = spec.worldbody.add_geom(def1)
572+
geom2 = spec.worldbody.add_geom()
573+
self.assertEqual(geom1.classname.name, 'def1')
574+
self.assertEqual(geom2.classname.name, 'main')
575+
576+
spec.compile()
577+
self.assertEqual(spec.to_xml(), XML)
578+
579+
spec = mujoco.MjSpec()
580+
spec.modelname = 'test'
581+
582+
main = spec.default()
583+
main.geom.size[0] = 2
584+
def1 = spec.add_default('def1', main)
585+
def1.geom.size[0] = 3
586+
587+
geom1 = spec.worldbody.add_geom(size=[3, 0, 0])
588+
geom2 = spec.worldbody.add_geom(size=[2, 0, 0])
589+
geom1.classname = def1
590+
geom2.classname = main # actually redundant, since main is always applied
591+
592+
spec.compile()
593+
self.assertEqual(spec.to_xml(), XML)
597594

598595
def test_element_list(self):
599596
spec = mujoco.MjSpec()

0 commit comments

Comments
 (0)