Skip to content

Commit 9549683

Browse files
authored
Use memcpy in AMSMessage (#106)
Fixes #105 Signed-off-by: Loic Pottier <pottier1@llnl.gov>
1 parent e537da9 commit 9549683

File tree

5 files changed

+48
-21
lines changed

5 files changed

+48
-21
lines changed

src/AMSWorkflow/ams/rmq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.a
163163

164164
idim = header_info["input_dim"]
165165
odim = header_info["output_dim"]
166-
data = data.reshape((-1, idim + odim))
166+
data = data.reshape((idim + odim, -1)).transpose()
167167
# Return input, output
168168
return (domain_name, data[:, :idim], data[:, idim:])
169169

src/AMSlib/wf/basedb.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ class AMSMessage
814814
_data(nullptr),
815815
_total_size(0)
816816
{
817+
CALIPER(CALI_MARK_BEGIN("AMS_MESSAGE");)
817818
AMSMsgHeader header(_rank,
818819
domain_name.size(),
819820
_num_elements,
@@ -832,10 +833,11 @@ class AMSMessage
832833
domain_name.size());
833834
current_offset += domain_name.size();
834835
current_offset +=
835-
encode_data(reinterpret_cast<TypeValue*>(_data + current_offset),
836-
inputs,
837-
outputs);
836+
encode_data(_data + current_offset,
837+
inputs,
838+
outputs);
838839
DBG(AMSMessage, "Allocated message %d: %p", _id, _data);
840+
CALIPER(CALI_MARK_END("AMS_MESSAGE");)
839841
}
840842

841843
/**
@@ -881,26 +883,25 @@ class AMSMessage
881883
* @return The number of bytes in the message or 0 if error
882884
*/
883885
template <typename TypeValue>
884-
size_t encode_data(TypeValue* data_blob,
886+
size_t encode_data(uint8_t* data_blob,
885887
const std::vector<TypeValue*>& inputs,
886888
const std::vector<TypeValue*>& outputs)
887889
{
888-
size_t x_dim = _input_dim + _output_dim;
889890
if (!data_blob) return 0;
890-
// Creating the body part of the messages
891-
for (size_t i = 0; i < _num_elements; i++) {
892-
for (size_t j = 0; j < _input_dim; j++) {
893-
data_blob[i * x_dim + j] = inputs[j][i];
894-
}
891+
size_t offset = 0;
892+
893+
// Creating the body part of the message
894+
for (size_t i = 0; i < _input_dim; i++) {
895+
std::memcpy(data_blob + offset, inputs[i], _num_elements * sizeof(TypeValue));
896+
offset += (_num_elements * sizeof(TypeValue));
895897
}
896898

897-
for (size_t i = 0; i < _num_elements; i++) {
898-
for (size_t j = 0; j < _output_dim; j++) {
899-
data_blob[i * x_dim + _input_dim + j] = outputs[j][i];
900-
}
899+
for (size_t i = 0; i < _output_dim; i++) {
900+
std::memcpy(data_blob + offset, outputs[i], _num_elements * sizeof(TypeValue));
901+
offset += (_num_elements * sizeof(TypeValue));
901902
}
902903

903-
return (x_dim * _num_elements) * sizeof(TypeValue);
904+
return ((_input_dim + _output_dim) * _num_elements) * sizeof(TypeValue);
904905
}
905906

906907
/**

tests/AMSlib/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function(INTEGRATION_TEST_RMQ)
7272
endif()
7373
set(JSON_FP "${CMAKE_CURRENT_BINARY_DIR}/rmq.json")
7474
CHECK_RMQ_CONFIG(${JSON_FP})
75-
add_test(NAME AMSEndToEndFromJSON::NoModel::Double::DB::rmq::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_rmq 0 8 9 \"double\" 2 1024; AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_rmq.py 0 8 9 \"double\" 2 1024")
75+
add_test(NAME AMSEndToEndFromJSON::NoModel::Double::DB::rmq::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_rmq 0 2 2 \"double\" 2 10; AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_rmq.py 0 2 2 \"double\" 2 10")
7676
endif()
7777
endfunction()
7878

tests/AMSlib/verify_rmq.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
from pathlib import Path
44
import os
5+
import numpy as np
56

67
from ams.rmq import BlockingClient, default_ams_callback
78

@@ -37,18 +38,43 @@ def verify(
3738

3839
assert len(msgs) == num_iterations, f"Received incorrect number of messsages ({len(msgs)}): expected #msgs ({num_iterations})"
3940

41+
expected_input = np.array([[0., 0.],
42+
[1., 1.],
43+
[2., 2.],
44+
[3., 3.],
45+
[4., 4.],
46+
[5., 5.],
47+
[6., 6.],
48+
[7., 7.],
49+
[8., 8.],
50+
[9., 9.]]
51+
)
52+
53+
expected_output = np.array([[ 0., 0.],
54+
[ 2., 2.],
55+
[ 4., 4.],
56+
[ 6., 6.],
57+
[ 8., 8.],
58+
[10., 10.],
59+
[12., 12.],
60+
[14., 14.],
61+
[16., 16.],
62+
[18., 18.]]
63+
)
64+
4065
for i, msg in enumerate(msgs):
41-
domain, _, _ = msg.decode()
66+
domain, input_data, output_data = msg.decode()
4267
assert msg.num_elements == num_elements, f"Message #{i}: incorrect #elements ({msg.num_element}) vs. expected #elem {num_elements})"
4368
assert msg.input_dim == num_inputs, f"Message #{i}: incorrect #inputs ({msg.input_dim}) vs. expected #inputs {num_inputs})"
4469
assert msg.output_dim == num_outputs, f"Message #{i}: incorrect #outputs ({msg.output_dim}) vs. expected #outputs {num_outputs})"
4570
assert msg.dtype_byte == dtype, f"Message #{i}: incorrect datatype ({msg.dtype_byte} bytes) vs. expected type {dtype} bytes)"
4671
assert domain == domain_test, f"Message #{i}: incorrect domain name (got {domain}) expected rmq_db_no_model)"
72+
assert np.array_equal(input_data, expected_input), f"Message #{i}: incorrect incorrect input data"
73+
assert np.array_equal(output_data, expected_output), f"Message #{i}: incorrect incorrect output data"
4774

4875
return 0
4976

5077
def from_json(argv):
51-
print(argv)
5278
use_device = int(argv[0])
5379
num_inputs = int(argv[1])
5480
num_outputs = int(argv[2])
@@ -72,7 +98,7 @@ def from_json(argv):
7298
num_iterations,
7399
num_elements,
74100
rmq_json["db"]["rmq_config"],
75-
timeout = 60 # in seconds
101+
timeout = 20 # in seconds
76102
)
77103
if res != 0:
78104
return res

tools/rmq/recv_binary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def parse_data(body: str, header_info: dict) -> Tuple[str, np.array, np.array]:
100100

101101
idim = header_info["input_dim"]
102102
odim = header_info["output_dim"]
103-
data = data.reshape((-1, idim + odim))
103+
data = data.reshape((idim + odim, -1)).transpose()
104104
# Return input, output
105105
return (domain_name, data[:, :idim], data[:, idim:])
106106

0 commit comments

Comments
 (0)