Skip to content

Commit f5de0ed

Browse files
dpfausaran-t
authored andcommitted
Add code for generating s3o4d data
PiperOrigin-RevId: 535584807
1 parent 9176a9f commit f5de0ed

File tree

1 file changed

+172
-0
lines changed

1 file changed

+172
-0
lines changed

geomancer/data_writer.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright 2023 DeepMind Technologies Limited.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Data writer for Stanford Bunny experiments and other objects."""
16+
17+
# pylint: disable=unused-import
18+
import copy
19+
import io
20+
import os
21+
import time
22+
23+
from absl import app
24+
from absl import flags
25+
from absl import logging
26+
27+
from dm_control import mujoco
28+
29+
import numpy as np
30+
31+
32+
_SHARD = flags.DEFINE_integer('shard', 0, 'Shard index')
33+
_SIZE = flags.DEFINE_integer('size', 1000,
34+
'Number of images to save to a shard')
35+
_OBJECT = flags.DEFINE_string('object', 'dragon', 'Which object to render')
36+
_PATH = flags.DEFINE_string('path', '', 'Path to folder with .stl files')
37+
38+
39+
render_height = 1024
40+
render_width = 1024
41+
42+
height = 256
43+
width = 256
44+
45+
46+
def get_normal(x):
47+
"""Get vectors normal to a unit vector."""
48+
_, _, v = np.linalg.svd(x[None, :])
49+
return v[:, 1:]
50+
51+
52+
def render(quat, light, mesh='bunny', meshdir='data'):
53+
"""Script to render an image."""
54+
scale, pos = None, None
55+
if mesh == 'bunny':
56+
scale = 0.03
57+
pos = -1.0
58+
elif mesh == 'dragon':
59+
scale = 0.06
60+
pos = -0.3
61+
62+
simple_world_mjcf_template = """
63+
<mujoco>
64+
<visual>
65+
<headlight active="0"/>
66+
<global offwidth="%s" offheight="%s"/>
67+
</visual>
68+
<compiler meshdir="%s"/>
69+
<asset>
70+
<mesh name="%s" file="%s.stl" scale="%g %g %g"/>
71+
</asset>
72+
<worldbody>
73+
<camera name="main" pos="0 0 5" xyaxes="1 0 0 0 1 0"/>
74+
<body name="obj" quat="{} {} {} {}">
75+
<geom name="%s" type="mesh" mesh="%s" pos="0 0 %g"/>
76+
</body>
77+
<light pos="{} {} {}" directional="true" dir="{} {} {}"/>
78+
<light pos="{} {} {}" directional="true" dir="{} {} {}"/>
79+
</worldbody>
80+
</mujoco>
81+
""" % (render_width, render_height, meshdir, mesh, mesh,
82+
scale, scale, scale, mesh, mesh, pos)
83+
84+
light /= np.linalg.norm(light)
85+
quat /= np.linalg.norm(quat)
86+
87+
simple_world_mjcf = simple_world_mjcf_template.format(
88+
*(np.concatenate((quat,
89+
5*light, -5*light,
90+
-5*light, 5*light)).tolist()))
91+
physics = mujoco.Physics.from_xml_string(simple_world_mjcf)
92+
data = physics.render(camera_id='main',
93+
height=render_height,
94+
width=render_width).astype(np.float32)
95+
data = data.reshape((width, int(render_width/width),
96+
height, int(render_height/height), 3))
97+
return np.mean(np.mean(data, axis=1), axis=2)
98+
99+
100+
def get_tangent(quat, light, mesh='bunny', meshdir='data',
101+
eps=0.03, use_light=True, use_quat=True):
102+
"""Render image along with its tangent vectors by finite differences."""
103+
assert use_light or use_quat
104+
n = 0
105+
light_tangent = None
106+
quat_tangent = None
107+
if use_light:
108+
light_tangent = get_normal(light)
109+
n += 2
110+
if use_quat:
111+
quat_tangent = get_normal(quat)
112+
n += 3
113+
114+
# a triple-wide pixel buffer
115+
try:
116+
data = render(quat, light, mesh=mesh, meshdir=meshdir)
117+
image_tangent = np.zeros((n, height, width, 3), dtype=np.float32)
118+
119+
if use_quat:
120+
for i in range(3):
121+
perturbed = render(quat + eps * quat_tangent[:, i],
122+
light, mesh=mesh, meshdir=meshdir)
123+
image_tangent[i] = (perturbed - data).astype(np.float32) / eps
124+
125+
if use_light:
126+
j = 3 if use_quat else 0
127+
for i in range(2):
128+
perturbed = render(quat, light + eps * light_tangent[:, i],
129+
mesh=mesh, meshdir=meshdir)
130+
image_tangent[i+j] = (perturbed - data) / eps
131+
132+
image_tangent -= np.mean(image_tangent, axis=0)[None, ...]
133+
latent_tangent = np.block(
134+
[[quat_tangent, np.zeros((4, 2), dtype=np.float32)],
135+
[np.zeros((3, 3), dtype=np.float32), light_tangent]])
136+
137+
return (np.mean(data, axis=-1),
138+
np.mean(image_tangent, axis=-1),
139+
latent_tangent)
140+
except: # pylint: disable=bare-except
141+
logging.info('Failed with latents (quat: %s, light: %s)', quat, light)
142+
143+
144+
def main(_):
145+
images = np.zeros((_SIZE.value, height, width), dtype=np.float32)
146+
latents = np.zeros((_SIZE.value, 7), dtype=np.float32)
147+
148+
image_tangents = np.zeros((_SIZE.value, 5, height, width), dtype=np.float32)
149+
latent_tangents = np.zeros((_SIZE.value, 7, 5), dtype=np.float32)
150+
151+
for i in range(_SIZE.value):
152+
light = np.random.randn(3)
153+
light /= np.linalg.norm(light)
154+
155+
quat = np.random.randn(4) # rotation represented as quaternion
156+
quat /= np.linalg.norm(quat)
157+
158+
latents[i] = np.concatenate((light, quat))
159+
images[i], image_tangents[i], latent_tangents[i] = (
160+
get_tangent(quat, light, mesh=_OBJECT.value, meshdir=_PATH.value))
161+
logging.info('Rendered image %d of %d', i, _SIZE.value)
162+
163+
os.makedirs(os.path.join(_PATH.value, _OBJECT.value), exist_ok=True)
164+
with open(os.path.join(
165+
_PATH.value, _OBJECT.value, 'shard_%03d.npz' % _SHARD.value), 'wb') as f:
166+
io_buffer = io.BytesIO()
167+
np.savez(io_buffer, images, latents, image_tangents, latent_tangents)
168+
f.write(io_buffer.getvalue())
169+
170+
171+
if __name__ == '__main__':
172+
app.run(main)

0 commit comments

Comments
 (0)