Skip to content

Commit e28230c

Browse files
Add API for MJPC cost residuals.
PiperOrigin-RevId: 604774074 Change-Id: I2b4463daf24db079591b58fb56dc3c1eb42ca5f9
1 parent d81174c commit e28230c

File tree

10 files changed

+104
-0
lines changed

10 files changed

+104
-0
lines changed

mjpc/grpc/agent.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ service Agent {
3939
returns (GetTaskParametersResponse);
4040
// Set cost weights.
4141
rpc SetCostWeights(SetCostWeightsRequest) returns (SetCostWeightsResponse);
42+
// Get cost term residuals.
43+
rpc GetResiduals(GetResidualsRequest) returns (GetResidualsResponse);
4244
// Get cost term values.
4345
rpc GetCostValuesAndWeights(GetCostValuesAndWeightsRequest)
4446
returns (GetCostValuesAndWeightsResponse);
@@ -113,6 +115,16 @@ message GetActionResponse {
113115
repeated float action = 1 [packed = true];
114116
}
115117

118+
message GetResidualsRequest {}
119+
120+
message Residual {
121+
repeated double values = 1;
122+
}
123+
124+
message GetResidualsResponse {
125+
map<string, Residual> values = 1;
126+
}
127+
116128
message GetCostValuesAndWeightsRequest {}
117129

118130
message ValueAndWeight {

mjpc/grpc/agent_service.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ using ::agent::GetAllModesRequest;
3535
using ::agent::GetAllModesResponse;
3636
using ::agent::GetBestTrajectoryRequest;
3737
using ::agent::GetBestTrajectoryResponse;
38+
using ::agent::GetResidualsRequest;
39+
using ::agent::GetResidualsResponse;
3840
using ::agent::GetCostValuesAndWeightsRequest;
3941
using ::agent::GetCostValuesAndWeightsResponse;
4042
using ::agent::GetModeRequest;
@@ -181,6 +183,16 @@ grpc::Status AgentService::GetAction(grpc::ServerContext* context,
181183
request, &agent_, model, rollout_data_.get(), &rollout_state_, response);
182184
}
183185

186+
grpc::Status AgentService::GetResiduals(
187+
grpc::ServerContext* context, const GetResidualsRequest* request,
188+
GetResidualsResponse* response) {
189+
if (!Initialized()) {
190+
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
191+
}
192+
return grpc_agent_util::GetResiduals(request, &agent_, model,
193+
data_, response);
194+
}
195+
184196
grpc::Status AgentService::GetCostValuesAndWeights(
185197
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
186198
GetCostValuesAndWeightsResponse* response) {

mjpc/grpc/agent_service.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class AgentService final : public agent::Agent::Service {
5858
const agent::GetActionRequest* request,
5959
agent::GetActionResponse* response) override;
6060

61+
grpc::Status GetResiduals(
62+
grpc::ServerContext* context,
63+
const agent::GetResidualsRequest* request,
64+
agent::GetResidualsResponse* response) override;
65+
6166
grpc::Status GetCostValuesAndWeights(
6267
grpc::ServerContext* context,
6368
const agent::GetCostValuesAndWeightsRequest* request,

mjpc/grpc/agent_service_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,4 +360,16 @@ TEST_F(AgentServiceTest, GetAllModes_Works) {
360360
EXPECT_EQ(response.mode_names()[0], "default_mode");
361361
}
362362

363+
TEST_F(AgentServiceTest, GetResiduals_Works) {
364+
RunAndCheckInit("Cartpole", nullptr);
365+
366+
grpc::ClientContext context;
367+
368+
agent::GetResidualsRequest request;
369+
agent::GetResidualsResponse response;
370+
grpc::Status status = stub->GetResiduals(&context, request, &response);
371+
372+
EXPECT_TRUE(status.ok());
373+
}
374+
363375
} // namespace mjpc::agent_grpc

mjpc/grpc/grpc_agent_util.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ using ::agent::GetActionRequest;
4444
using ::agent::GetActionResponse;
4545
using ::agent::GetAllModesRequest;
4646
using ::agent::GetAllModesResponse;
47+
using ::agent::GetResidualsRequest;
48+
using ::agent::GetResidualsResponse;
4749
using ::agent::GetCostValuesAndWeightsRequest;
4850
using ::agent::GetCostValuesAndWeightsResponse;
4951
using ::agent::GetModeRequest;
@@ -58,6 +60,7 @@ using ::agent::SetCostWeightsRequest;
5860
using ::agent::SetModeRequest;
5961
using ::agent::SetStateRequest;
6062
using ::agent::SetTaskParametersRequest;
63+
using ::agent::Residual;
6164
using ::agent::ValueAndWeight;
6265

6366
grpc::Status GetState(const mjModel* model, const mjData* data,
@@ -226,6 +229,34 @@ grpc::Status GetAction(const GetActionRequest* request,
226229
return grpc::Status::OK;
227230
}
228231

232+
grpc::Status GetResiduals(
233+
const GetResidualsRequest* request, const mjpc::Agent* agent,
234+
const mjModel* model, mjData* data,
235+
GetResidualsResponse* response) {
236+
const mjModel* agent_model = agent->GetModel();
237+
const mjpc::Task* task = agent->ActiveTask();
238+
std::vector<double> residuals(task->num_residual, 0); // scratch space
239+
task->Residual(model, data, residuals.data());
240+
std::vector<int> dim_norm_residual = task->dim_norm_residual;
241+
242+
int residual_shift = 0;
243+
for (int i = 0; i < task->num_term; i++) {
244+
CHECK_EQ(agent_model->sensor_type[i], mjSENS_USER);
245+
std::string_view sensor_name(agent_model->names +
246+
agent_model->name_sensoradr[i]);
247+
248+
std::vector<double> sensor_residual_values(
249+
residuals.begin() + residual_shift,
250+
residuals.begin() + residual_shift + dim_norm_residual[i]);
251+
Residual sensor_residual;
252+
sensor_residual.mutable_values()->Assign(sensor_residual_values.begin(),
253+
sensor_residual_values.end());
254+
(*response->mutable_values())[sensor_name] = sensor_residual;
255+
residual_shift += dim_norm_residual[i];
256+
}
257+
return grpc::Status::OK;
258+
}
259+
229260
grpc::Status GetCostValuesAndWeights(
230261
const GetCostValuesAndWeightsRequest* request, const mjpc::Agent* agent,
231262
const mjModel* model, mjData* data,

mjpc/grpc/grpc_agent_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ grpc::Status GetAction(const agent::GetActionRequest* request,
3434
const mjModel* model, mjData* rollout_data,
3535
mjpc::State* rollout_state,
3636
agent::GetActionResponse* response);
37+
grpc::Status GetResiduals(
38+
const agent::GetResidualsRequest* request,
39+
const mjpc::Agent* agent, const mjModel* model, mjData* data,
40+
agent::GetResidualsResponse* response);
3741
grpc::Status GetCostValuesAndWeights(
3842
const agent::GetCostValuesAndWeightsRequest* request,
3943
const mjpc::Agent* agent, const mjModel* model, mjData* data,

mjpc/grpc/ui_agent_service.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ using ::agent::GetActionRequest;
3737
using ::agent::GetActionResponse;
3838
using ::agent::GetModeRequest;
3939
using ::agent::GetModeResponse;
40+
using ::agent::GetResidualsRequest;
41+
using ::agent::GetResidualsResponse;
4042
using ::agent::GetCostValuesAndWeightsRequest;
4143
using ::agent::GetCostValuesAndWeightsResponse;
4244
using ::agent::GetStateRequest;
@@ -125,6 +127,18 @@ grpc::Status UiAgentService::GetAction(grpc::ServerContext* context,
125127
});
126128
}
127129

130+
grpc::Status UiAgentService::GetResiduals(
131+
grpc::ServerContext* context, const GetResidualsRequest* request,
132+
GetResidualsResponse* response) {
133+
return RunBeforeStep(
134+
context, [request, response](mjpc::Agent* agent, const mjModel* model,
135+
mjData* data) {
136+
return grpc_agent_util::GetResiduals(request, agent, model,
137+
data, response);
138+
});
139+
}
140+
141+
128142
grpc::Status UiAgentService::GetCostValuesAndWeights(
129143
grpc::ServerContext* context, const GetCostValuesAndWeightsRequest* request,
130144
GetCostValuesAndWeightsResponse* response) {

mjpc/grpc/ui_agent_service.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ class UiAgentService final : public agent::Agent::Service {
4949
const agent::GetActionRequest* request,
5050
agent::GetActionResponse* response) override;
5151

52+
grpc::Status GetResiduals(
53+
grpc::ServerContext* context,
54+
const agent::GetResidualsRequest* request,
55+
agent::GetResidualsResponse* response) override;
56+
5257
grpc::Status GetCostValuesAndWeights(
5358
grpc::ServerContext* context,
5459
const agent::GetCostValuesAndWeightsRequest* request,

python/mujoco_mpc/agent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,11 @@ def get_cost_term_values(self) -> dict[str, float]:
266266
for name, value_weight in terms.values_weights.items()
267267
}
268268

269+
def get_residuals(self) -> dict[str, Sequence[float]]:
270+
residuals = self.stub.GetResiduals(agent_pb2.GetResidualsRequest())
271+
return {name: residual.values
272+
for name, residual in residuals.values.items()}
273+
269274
def planner_step(self):
270275
"""Send a planner request."""
271276
planner_step_request = agent_pb2.PlannerStepRequest()

python/mujoco_mpc/agent_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ def test_get_cost_weights(self):
303303
terms = list(terms_dict.values())
304304
self.assertFalse(np.any(np.isclose(terms, 0, rtol=0, atol=1e-4)))
305305

306+
residuals_dict = agent.get_residuals()
307+
residuals = list(residuals_dict.values())
308+
self.assertFalse(np.any(np.isclose(residuals, 0, rtol=0, atol=1e-4)))
309+
306310
def test_set_state_with_lists(self):
307311
model_path = (
308312
pathlib.Path(__file__).parent.parent.parent

0 commit comments

Comments
 (0)