Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.a

idim = header_info["input_dim"]
odim = header_info["output_dim"]
data = data.reshape((-1, idim + odim))
data = data.reshape((idim + odim, -1)).transpose()
# Return input, output
return (domain_name, data[:, :idim], data[:, idim:])

Expand Down
31 changes: 16 additions & 15 deletions src/AMSlib/wf/basedb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ class AMSMessage
_data(nullptr),
_total_size(0)
{
CALIPER(CALI_MARK_BEGIN("AMS_MESSAGE");)
AMSMsgHeader header(_rank,
domain_name.size(),
_num_elements,
Expand All @@ -832,10 +833,11 @@ class AMSMessage
domain_name.size());
current_offset += domain_name.size();
current_offset +=
encode_data(reinterpret_cast<TypeValue*>(_data + current_offset),
inputs,
outputs);
encode_data(_data + current_offset,
inputs,
outputs);
DBG(AMSMessage, "Allocated message %d: %p", _id, _data);
CALIPER(CALI_MARK_END("AMS_MESSAGE");)
}

/**
Expand Down Expand Up @@ -881,26 +883,25 @@ class AMSMessage
* @return The number of bytes in the message or 0 if error
*/
template <typename TypeValue>
size_t encode_data(TypeValue* data_blob,
size_t encode_data(uint8_t* data_blob,
const std::vector<TypeValue*>& inputs,
const std::vector<TypeValue*>& outputs)
{
size_t x_dim = _input_dim + _output_dim;
if (!data_blob) return 0;
// Creating the body part of the messages
for (size_t i = 0; i < _num_elements; i++) {
for (size_t j = 0; j < _input_dim; j++) {
data_blob[i * x_dim + j] = inputs[j][i];
}
size_t offset = 0;

// Creating the body part of the message
for (size_t i = 0; i < _input_dim; i++) {
std::memcpy(data_blob + offset, inputs[i], _num_elements * sizeof(TypeValue));
offset += (_num_elements * sizeof(TypeValue));
}

for (size_t i = 0; i < _num_elements; i++) {
for (size_t j = 0; j < _output_dim; j++) {
data_blob[i * x_dim + _input_dim + j] = outputs[j][i];
}
for (size_t i = 0; i < _output_dim; i++) {
std::memcpy(data_blob + offset, outputs[i], _num_elements * sizeof(TypeValue));
offset += (_num_elements * sizeof(TypeValue));
}

return (x_dim * _num_elements) * sizeof(TypeValue);
return ((_input_dim + _output_dim) * _num_elements) * sizeof(TypeValue);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion tests/AMSlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function(INTEGRATION_TEST_RMQ)
endif()
set(JSON_FP "${CMAKE_CURRENT_BINARY_DIR}/rmq.json")
CHECK_RMQ_CONFIG(${JSON_FP})
add_test(NAME AMSEndToEndFromJSON::NoModel::Double::DB::rmq::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_rmq 0 8 9 \"double\" 2 1024; AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_rmq.py 0 8 9 \"double\" 2 1024")
add_test(NAME AMSEndToEndFromJSON::NoModel::Double::DB::rmq::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_rmq 0 2 2 \"double\" 2 10; AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_rmq.py 0 2 2 \"double\" 2 10")
endif()
endfunction()

Expand Down
32 changes: 29 additions & 3 deletions tests/AMSlib/verify_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
from pathlib import Path
import os
import numpy as np

from ams.rmq import BlockingClient, default_ams_callback

Expand Down Expand Up @@ -37,18 +38,43 @@ def verify(

assert len(msgs) == num_iterations, f"Received incorrect number of messsages ({len(msgs)}): expected #msgs ({num_iterations})"

expected_input = np.array([[0., 0.],
[1., 1.],
[2., 2.],
[3., 3.],
[4., 4.],
[5., 5.],
[6., 6.],
[7., 7.],
[8., 8.],
[9., 9.]]
)

expected_output = np.array([[ 0., 0.],
[ 2., 2.],
[ 4., 4.],
[ 6., 6.],
[ 8., 8.],
[10., 10.],
[12., 12.],
[14., 14.],
[16., 16.],
[18., 18.]]
)

for i, msg in enumerate(msgs):
domain, _, _ = msg.decode()
domain, input_data, output_data = msg.decode()
assert msg.num_elements == num_elements, f"Message #{i}: incorrect #elements ({msg.num_element}) vs. expected #elem {num_elements})"
assert msg.input_dim == num_inputs, f"Message #{i}: incorrect #inputs ({msg.input_dim}) vs. expected #inputs {num_inputs})"
assert msg.output_dim == num_outputs, f"Message #{i}: incorrect #outputs ({msg.output_dim}) vs. expected #outputs {num_outputs})"
assert msg.dtype_byte == dtype, f"Message #{i}: incorrect datatype ({msg.dtype_byte} bytes) vs. expected type {dtype} bytes)"
assert domain == domain_test, f"Message #{i}: incorrect domain name (got {domain}) expected rmq_db_no_model)"
assert np.array_equal(input_data, expected_input), f"Message #{i}: incorrect incorrect input data"
assert np.array_equal(output_data, expected_output), f"Message #{i}: incorrect incorrect output data"

return 0

def from_json(argv):
print(argv)
use_device = int(argv[0])
num_inputs = int(argv[1])
num_outputs = int(argv[2])
Expand All @@ -72,7 +98,7 @@ def from_json(argv):
num_iterations,
num_elements,
rmq_json["db"]["rmq_config"],
timeout = 60 # in seconds
timeout = 20 # in seconds
)
if res != 0:
return res
Expand Down
2 changes: 1 addition & 1 deletion tools/rmq/recv_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def parse_data(body: str, header_info: dict) -> Tuple[str, np.array, np.array]:

idim = header_info["input_dim"]
odim = header_info["output_dim"]
data = data.reshape((-1, idim + odim))
data = data.reshape((idim + odim, -1)).transpose()
# Return input, output
return (domain_name, data[:, :idim], data[:, idim:])

Expand Down