Skip to content

Commit 46072a6

Browse files
committed
renamed SoAMetadata class to SoAWrapper
1 parent a587c05 commit 46072a6

File tree

11 files changed

+44
-42
lines changed

11 files changed

+44
-42
lines changed

PhysicsTools/PyTorch/README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ The interface provides a converter to dynamically wrap SoA data into one or more
6464

6565
#### Metadata
6666

67-
The structual information of the input and output SoA are stored in an `SoAMetadata`. These two objects are then combined to a `ModelMetadata`, to be used by the `Converter`.
67+
The structual information of the input and output SoA are stored in an `SoAWrapper`. These two objects are then combined to a `ModelMetadata`, to be used by the `Converter`.
6868

6969
#### Defining Metadata
7070

71-
The `SoAMetadata' can be defined by first initialising the object and then adding blocks to the metadata. Each block is transformed into a tensor whose size and type are derived from the columns provided.
71+
The `SoAWrapper' can be defined by first initialising the object and then adding blocks to the metadata. Each block is transformed into a tensor whose size and type are derived from the columns provided.
7272

7373
Example SOA Template for Model Input:
7474
```cpp
@@ -97,14 +97,14 @@ fill(queue, deviceCollection);
9797
auto records = deviceCollection.view().records();
9898
auto result_records = deviceResultCollection.view().records();
9999

100-
SoAMetadata<SoA> input(batch_size);
100+
SoAWrapper<SoA> input(batch_size);
101101
input.append_block("eigen_vector", records.a(), records.b());
102102
input.append_block("eigen_matrix", records.c());
103103
input.append_block("column", records.x(), records.y(), records.z());
104104
input.append_block("scalar", view.type());
105105
input.change_order({"column", "scalar", "eigen_matrix", "eigen_vector"});
106106

107-
SoAMetadata<SoA> output(batch_size);
107+
SoAWrapper<SoA> output(batch_size);
108108
output.append_block("result", result_view.cluster());
109109
ModelMetadata metadata(input, output);
110110
```
@@ -150,9 +150,9 @@ alpaka::wait(queue);
150150
// metadata for automatic tensor conversion
151151
auto input_records = inputs_device.view().records();
152152
auto output_records = outputs_device.view().records();
153-
cms::torch::alpaka::SoAMetadata<SoAInputs> inputs_metadata(batch_size);
153+
cms::torch::alpaka::SoAWrapper<SoAInputs> inputs_metadata(batch_size);
154154
inputs_metadata.append_block("features", input_records.x(), input_records.y(), input_records.z());
155-
cms::torch::alpaka::SoAMetadata<SoAOutputs> outputs_metadata(batch_size);
155+
cms::torch::alpaka::SoAWrapper<SoAOutputs> outputs_metadata(batch_size);
156156
outputs_metadata.append_block("preds", output_records.m(), output_records.n());
157157
cms::torch::alpaka::ModelMetadata<SoAInputs, SoAOutputs> metadata(inputs_metadata, outputs_metadata);
158158
@@ -180,4 +180,5 @@ The function `change_order()` in the allows specifying the order in which the bl
180180
- #9786 Pytorch with ROCm (temp): https://github.com/cms-sw/cmsdist/pull/9312
181181
- #9312 [WIP] Build PyTorch with ROCm https://github.com/cms-sw/cmsdist/pull/9786
182182
- AOT support is under active development and subject to changes that obey CMSSW releasing rules.
183-
- On more complex models with multiple output branches extra copy is needed
183+
- On more complex models with multiple output branches extra copy is needed
184+
- The alignment of the SoA has to be a multiple of the types used, as no bytewise padding is currently supported.

PhysicsTools/PyTorch/interface/Converter.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/torch.h>
55

66
#include "DataFormats/SoATemplate/interface/SoALayout.h"
7-
#include "PhysicsTools/PyTorch/interface/SoAMetadata.h"
7+
#include "PhysicsTools/PyTorch/interface/SoAWrapper.h"
88

99

1010
namespace cms::torch::alpaka {
@@ -13,14 +13,14 @@ namespace cms::torch::alpaka {
1313
template <typename SOA_Input, typename SOA_Output>
1414
class ModelMetadata {
1515
public:
16-
SoAMetadata<SOA_Input> input;
17-
SoAMetadata<SOA_Output> output;
16+
SoAWrapper<SOA_Input> input;
17+
SoAWrapper<SOA_Output> output;
1818

1919
// Used in AOT model class to correctly choose multi or single output conversion
2020
// Default value true, as single value can be parsed with multi output
2121
bool multi_output;
2222

23-
ModelMetadata(const SoAMetadata<SOA_Input>& input_, const SoAMetadata<SOA_Output>& output_, bool multi_output_=true)
23+
ModelMetadata(const SoAWrapper<SOA_Input>& input_, const SoAWrapper<SOA_Output>& output_, bool multi_output_=true)
2424
: input(input_), output(output_), multi_output(multi_output_) {}
2525
};
2626

PhysicsTools/PyTorch/interface/SoAMetadata.h renamed to PhysicsTools/PyTorch/interface/SoAWrapper.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef PHYSICS_TOOLS__PYTORCH__INTERFACE__SOAMETADATA_H_
2-
#define PHYSICS_TOOLS__PYTORCH__INTERFACE__SOAMETADATA_H_
1+
#ifndef PHYSICS_TOOLS__PYTORCH__INTERFACE__SoAWrapper_H_
2+
#define PHYSICS_TOOLS__PYTORCH__INTERFACE__SoAWrapper_H_
33

44
#include <iostream>
55

@@ -107,7 +107,7 @@ struct Block {
107107
// Metadata for SOA split into multiple blocks.
108108
// An order for the resulting tensors can be defined.
109109
template <typename SOA_Layout>
110-
struct SoAMetadata {
110+
struct SoAWrapper {
111111
private:
112112
std::map<std::string, Block<SOA_Layout>> blocks;
113113

@@ -145,7 +145,7 @@ struct SoAMetadata {
145145
int nElements;
146146
int nBlocks;
147147

148-
SoAMetadata(int nElements_) : nElements(nElements_), nBlocks(0) {}
148+
SoAWrapper(int nElements_) : nElements(nElements_), nBlocks(0) {}
149149

150150
// TODO: Check columns are contiguous
151151
template <typename T, typename... Others>
@@ -198,4 +198,4 @@ struct SoAMetadata {
198198

199199
} // namespace cms::torch::alpaka
200200

201-
#endif // PHYSICS_TOOLS__PYTORCH__INTERFACE__SOAMETADATA_H_
201+
#endif // PHYSICS_TOOLS__PYTORCH__INTERFACE__SoAWrapper_H_

PhysicsTools/PyTorch/plugins/alpaka/AotRegressionProducer.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "HeterogeneousCore/AlpakaInterface/interface/config.h"
1616
#include "PhysicsTools/PyTorch/interface/AlpakaConfig.h"
1717
#include "PhysicsTools/PyTorch/interface/Model.h"
18-
#include "PhysicsTools/PyTorch/interface/SoAMetadata.h"
18+
#include "PhysicsTools/PyTorch/interface/SoAWrapper.h"
1919
#include "PhysicsTools/PyTorch/interface/Nvtx.h"
2020
#include "PhysicsTools/PyTorch/plugins/alpaka/Kernels.h"
2121

@@ -96,9 +96,9 @@ void AotRegressionProducer::produce(device::Event &event, const device::EventSet
9696
// metadata for automatic tensor conversion
9797
auto input_records = inputs.view().records();
9898
auto output_records = outputs.view().records();
99-
cms::torch::alpaka::SoAMetadata<torchportable::ParticleSoA> inputs_metadata(batch_size);
99+
cms::torch::alpaka::SoAWrapper<torchportable::ParticleSoA> inputs_metadata(batch_size);
100100
inputs_metadata.append_block("features", input_records.pt(), input_records.eta(), input_records.phi());
101-
cms::torch::alpaka::SoAMetadata<torchportable::RegressionSoA> outputs_metadata(batch_size);
101+
cms::torch::alpaka::SoAWrapper<torchportable::RegressionSoA> outputs_metadata(batch_size);
102102
outputs_metadata.append_block("preds", output_records.reco_pt());
103103
cms::torch::alpaka::ModelMetadata<torchportable::ParticleSoA, torchportable::RegressionSoA> metadata(inputs_metadata, outputs_metadata);
104104

PhysicsTools/PyTorch/plugins/alpaka/JitClassificationProducer.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "HeterogeneousCore/AlpakaInterface/interface/config.h"
1616
#include "PhysicsTools/PyTorch/interface/AlpakaConfig.h"
1717
#include "PhysicsTools/PyTorch/interface/Model.h"
18-
#include "PhysicsTools/PyTorch/interface/SoAMetadata.h"
18+
#include "PhysicsTools/PyTorch/interface/SoAWrapper.h"
1919
#include "PhysicsTools/PyTorch/interface/Nvtx.h"
2020
#include "PhysicsTools/PyTorch/plugins/alpaka/Kernels.h"
2121

@@ -96,9 +96,9 @@ void JitClassificationProducer::produce(device::Event &event, const device::Even
9696
// metadata for automatic tensor conversion
9797
auto input_records = inputs.view().records();
9898
auto output_records = outputs.view().records();
99-
cms::torch::alpaka::SoAMetadata<torchportable::ParticleSoA> inputs_metadata(batch_size);
99+
cms::torch::alpaka::SoAWrapper<torchportable::ParticleSoA> inputs_metadata(batch_size);
100100
inputs_metadata.append_block("features", input_records.pt(), input_records.eta(), input_records.phi());
101-
cms::torch::alpaka::SoAMetadata<torchportable::ClassificationSoA> outputs_metadata(batch_size);
101+
cms::torch::alpaka::SoAWrapper<torchportable::ClassificationSoA> outputs_metadata(batch_size);
102102
outputs_metadata.append_block("preds", output_records.c1(), output_records.c2());
103103
cms::torch::alpaka::ModelMetadata<torchportable::ParticleSoA, torchportable::ClassificationSoA> metadata(inputs_metadata, outputs_metadata);
104104

PhysicsTools/PyTorch/test/BuildFile.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
</bin>
116116

117117
<bin file="testModel.cc testRunner.cc" name="testModel">
118+
<use name="boost_filesystem"/>
118119
<use name="cppunit"/>
119120
<use name="pytorch"/>
120121
<use name="pytorch-cuda"/>

PhysicsTools/PyTorch/test/alpaka/testPortableInferenceAOT.dev.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
1313
#include "PhysicsTools/PyTorch/interface/AlpakaConfig.h"
1414
#include "PhysicsTools/PyTorch/interface/Model.h"
15-
#include "PhysicsTools/PyTorch/interface/SoAMetadata.h"
15+
#include "PhysicsTools/PyTorch/interface/SoAWrapper.h"
1616
#include "PhysicsTools/PyTorch/test/testUtilities.h"
1717

1818

@@ -97,9 +97,9 @@ void TestPortableInferenceAOT::test() {
9797
// metadata for automatic tensor conversion
9898
auto input_records = inputs_device.view().records();
9999
auto output_records = outputs_device.view().records();
100-
cms::torch::alpaka::SoAMetadata<SoAInputs> inputs_metadata(batch_size);
100+
cms::torch::alpaka::SoAWrapper<SoAInputs> inputs_metadata(batch_size);
101101
inputs_metadata.append_block("features", input_records.x(), input_records.y(), input_records.z());
102-
cms::torch::alpaka::SoAMetadata<SoAOutputs> outputs_metadata(batch_size);
102+
cms::torch::alpaka::SoAWrapper<SoAOutputs> outputs_metadata(batch_size);
103103
outputs_metadata.append_block("prob", output_records.prob());
104104
cms::torch::alpaka::ModelMetadata<SoAInputs, SoAOutputs> metadata(inputs_metadata, outputs_metadata);
105105
// inference

PhysicsTools/PyTorch/test/alpaka/testPortableInferenceJIT.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
1313
#include "PhysicsTools/PyTorch/interface/AlpakaConfig.h"
1414
#include "PhysicsTools/PyTorch/interface/Model.h"
15-
#include "PhysicsTools/PyTorch/interface/SoAMetadata.h"
15+
#include "PhysicsTools/PyTorch/interface/SoAWrapper.h"
1616
#include "PhysicsTools/PyTorch/test/testUtilities.h"
1717

1818

@@ -84,9 +84,9 @@ void TestPortableInferenceJIT::test() {
8484
// metadata for automatic tensor conversion
8585
auto input_records = inputs_device.view().records();
8686
auto output_records = outputs_device.view().records();
87-
cms::torch::alpaka::SoAMetadata<SoAInputs> inputs_metadata(batch_size);
87+
cms::torch::alpaka::SoAWrapper<SoAInputs> inputs_metadata(batch_size);
8888
inputs_metadata.append_block("features", input_records.x(), input_records.y(), input_records.z());
89-
cms::torch::alpaka::SoAMetadata<SoAOutputs> outputs_metadata(batch_size);
89+
cms::torch::alpaka::SoAWrapper<SoAOutputs> outputs_metadata(batch_size);
9090
outputs_metadata.append_block("preds", output_records.m(), output_records.n());
9191
cms::torch::alpaka::ModelMetadata<SoAInputs, SoAOutputs> metadata(inputs_metadata, outputs_metadata);
9292
// inference

PhysicsTools/PyTorch/test/alpaka/testSOADataTypes.dev.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,14 @@ void TestSOADataTypes::testInterfaceVerbose() {
254254
fill(queue, deviceCollection);
255255
SoAMetaRecords records = deviceCollection.view().records();
256256

257-
SoAMetadata<SoA> input(batch_size);
257+
SoAWrapper<SoA> input(batch_size);
258258
input.append_block("vector", records.a(), records.b());
259259
input.append_block("matrix", records.c());
260260
input.append_block("column", records.x(), records.y(), records.z());
261261
input.append_block("scalar", records.type());
262262
input.change_order({"column", "scalar", "matrix", "vector"});
263263

264-
SoAMetadata<SoA> output(batch_size);
264+
SoAWrapper<SoA> output(batch_size);
265265
output.append_block("result", records.v());
266266
ModelMetadata metadata(input, output);
267267

@@ -289,11 +289,11 @@ void TestSOADataTypes::testMultiOutput() {
289289
fill(queue, deviceCollection);
290290

291291
auto records = deviceCollection.view().records();
292-
SoAMetadata<SoA> input(batch_size);
292+
SoAWrapper<SoA> input(batch_size);
293293
input.append_block("x", records.x());
294294
input.append_block("y", records.y());
295295

296-
SoAMetadata<SoA> output(batch_size);
296+
SoAWrapper<SoA> output(batch_size);
297297
output.append_block("v", records.v());
298298
output.append_block("w", records.w());
299299
ModelMetadata metadata(input, output);
@@ -322,14 +322,14 @@ void TestSOADataTypes::testSingleElement() {
322322
SoAMetaRecords records = deviceCollection.view().records();
323323

324324
// Run Converter for single tensor
325-
SoAMetadata<SoA> input(batch_size);
325+
SoAWrapper<SoA> input(batch_size);
326326
input.append_block("vector", records.a(), records.b());
327327
input.append_block("matrix", records.c());
328328
input.append_block("column", records.x(), records.y(), records.z());
329329
input.append_block("scalar", records.type());
330330
input.change_order({"column", "scalar", "matrix", "vector"});
331331

332-
SoAMetadata<SoA> output(batch_size);
332+
SoAWrapper<SoA> output(batch_size);
333333
output.append_block("result", records.v());
334334
ModelMetadata metadata(input, output);
335335

@@ -355,14 +355,14 @@ void TestSOADataTypes::testNoElement() {
355355
SoAMetaRecords records = deviceCollection.view().records();
356356

357357
// Run Converter
358-
SoAMetadata<SoA> input(batch_size);
358+
SoAWrapper<SoA> input(batch_size);
359359
input.append_block("vector", records.a(), records.b());
360360
input.append_block("matrix", records.c());
361361
input.append_block("column", records.x(), records.y(), records.z());
362362
input.append_block("scalar", records.type());
363363
input.change_order({"column", "scalar", "matrix", "vector"});
364364

365-
SoAMetadata<SoA> output(batch_size);
365+
SoAWrapper<SoA> output(batch_size);
366366
output.append_block("result", records.v());
367367
ModelMetadata metadata(input, output);
368368

@@ -392,8 +392,8 @@ void TestSOADataTypes::testEmptyMetadata() {
392392
fill(queue, deviceCollection);
393393

394394
// Run Converter for empty metadata
395-
SoAMetadata<SoA> input(batch_size);
396-
SoAMetadata<SoA> output(batch_size);
395+
SoAWrapper<SoA> input(batch_size);
396+
SoAWrapper<SoA> output(batch_size);
397397
ModelMetadata metadata(input, output);
398398

399399
alpaka::wait(queue);

PhysicsTools/PyTorch/test/alpaka/testSOAtoTorch.dev.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ void testSOAToTorch::test() {
125125
}
126126

127127
// Create SoA Metadata
128-
cms::torch::alpaka::SoAMetadata<SoAPosition> input(batch_size);
128+
cms::torch::alpaka::SoAWrapper<SoAPosition> input(batch_size);
129129
auto posview = positionCollection.view().records();
130130
input.append_block("main", posview.x(), posview.y(), posview.z());
131131

132-
cms::torch::alpaka::SoAMetadata<SoAResult> output(batch_size);
132+
cms::torch::alpaka::SoAWrapper<SoAResult> output(batch_size);
133133
auto view = resultCollection.view().records();
134134
output.append_block("result", view.x(), view.y());
135135
cms::torch::alpaka::ModelMetadata metadata(input, output);

0 commit comments

Comments
 (0)