Skip to content

Commit 10477ee

Browse files
ai-edge-botcopybara-github
authored andcommitted
Refactor python wrapper
LiteRT-PiperOrigin-RevId: 775363263
1 parent bfb1db7 commit 10477ee

File tree

7 files changed

+114
-125
lines changed

7 files changed

+114
-125
lines changed

litert/python/litert_wrapper/common/litert_wrapper_utils.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
#include "litert/c/litert_tensor_buffer.h"
2121
#include "litert/cc/litert_tensor_buffer.h"
2222

23-
namespace litert {
24-
namespace litert_wrapper_utils {
23+
namespace litert::litert_wrapper_utils {
2524

2625
void DestroyTensorBufferFromCapsule(PyObject* capsule) {
2726
// TODO(b/414622532): Remove this check, using PyCapsule_GetPointer default
@@ -37,10 +36,9 @@ void DestroyTensorBufferFromCapsule(PyObject* capsule) {
3736
}
3837
}
3938

40-
PyObject* MakeTensorBufferCapsule(litert::TensorBuffer& buffer) {
39+
PyObject* MakeTensorBufferCapsule(TensorBuffer& buffer) {
4140
return PyCapsule_New(buffer.Release(), kLiteRtTensorBufferName.data(),
4241
&DestroyTensorBufferFromCapsule);
4342
}
4443

45-
} // namespace litert_wrapper_utils
46-
} // namespace litert
44+
} // namespace litert::litert_wrapper_utils

litert/python/litert_wrapper/common/litert_wrapper_utils.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
#include "absl/strings/string_view.h" // from @com_google_absl
2121
#include "litert/cc/litert_tensor_buffer.h"
2222

23-
namespace litert {
24-
namespace litert_wrapper_utils {
23+
namespace litert::litert_wrapper_utils {
2524

2625
// The name used for LiteRtTensorBuffer capsules
2726
constexpr absl::string_view kLiteRtTensorBufferName = "LiteRtTensorBuffer";
@@ -30,10 +29,9 @@ constexpr absl::string_view kLiteRtTensorBufferName = "LiteRtTensorBuffer";
3029
// to prevent double destruction. Returns true if successful.
3130
void DestroyTensorBufferFromCapsule(PyObject* capsule);
3231

33-
// Creates a PyCapsule for a TensorBuffer with appropriate destructor.
34-
PyObject* MakeTensorBufferCapsule(litert::TensorBuffer& buffer);
32+
// Creates a PyCapsule for a TensorBuffer with the appropriate destructor.
33+
PyObject* MakeTensorBufferCapsule(TensorBuffer& buffer);
3534

36-
} // namespace litert_wrapper_utils
37-
} // namespace litert
35+
} // namespace litert::litert_wrapper_utils
3836

3937
#endif // LITERT_PYTHON_LITERT_WRAPPER_COMMON_LITERT_WRAPPER_UTILS_H_

litert/python/litert_wrapper/compiled_model_wrapper/compiled_model_wrapper.cc

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535
#include "litert/cc/litert_tensor_buffer.h"
3636
#include "litert/python/litert_wrapper/common/litert_wrapper_utils.h"
3737

38-
namespace litert {
39-
namespace compiled_model_wrapper {
38+
namespace litert::compiled_model_wrapper {
4039

4140
// Returns the byte width of a data type.
4241
size_t CompiledModelWrapper::ByteWidthOfDType(const std::string& dtype) {
@@ -47,9 +46,8 @@ size_t CompiledModelWrapper::ByteWidthOfDType(const std::string& dtype) {
4746
}
4847

4948
// Constructor for CompiledModelWrapper.
50-
CompiledModelWrapper::CompiledModelWrapper(litert::Environment env,
51-
litert::Model model,
52-
litert::CompiledModel compiled)
49+
CompiledModelWrapper::CompiledModelWrapper(Environment env, Model model,
50+
CompiledModel compiled)
5351
: environment_(std::move(env)),
5452
model_(std::move(model)),
5553
compiled_model_(std::move(compiled)) {}
@@ -71,8 +69,7 @@ PyObject* CompiledModelWrapper::ReportError(const std::string& msg) {
7169
}
7270

7371
// Converts a LiteRT error to a Python exception.
74-
PyObject* CompiledModelWrapper::ConvertErrorToPyExc(
75-
const litert::Error& error) {
72+
PyObject* CompiledModelWrapper::ConvertErrorToPyExc(const Error& error) {
7673
PyErr_Format(PyExc_RuntimeError, "CompiledModel error: code=%d, message=%s",
7774
static_cast<int>(error.Status()), error.Message().c_str());
7875
return nullptr;
@@ -83,36 +80,36 @@ CompiledModelWrapper* CompiledModelWrapper::CreateWrapperFromFile(
8380
const char* model_path, const char* compiler_plugin_path,
8481
const char* dispatch_library_path, int hardware_accel,
8582
std::string* out_error) {
86-
// Create environment with options
87-
std::vector<litert::Environment::Option> options;
83+
// Create an environment with options
84+
std::vector<Environment::Option> options;
8885
if (compiler_plugin_path && *compiler_plugin_path) {
89-
options.push_back(litert::Environment::Option{
90-
litert::Environment::OptionTag::CompilerPluginLibraryDir,
91-
std::string(compiler_plugin_path)});
86+
options.push_back(
87+
Environment::Option{Environment::OptionTag::CompilerPluginLibraryDir,
88+
std::string(compiler_plugin_path)});
9289
}
9390
if (dispatch_library_path && *dispatch_library_path) {
94-
options.push_back(litert::Environment::Option{
95-
litert::Environment::OptionTag::DispatchLibraryDir,
96-
std::string(dispatch_library_path)});
91+
options.push_back(
92+
Environment::Option{Environment::OptionTag::DispatchLibraryDir,
93+
std::string(dispatch_library_path)});
9794
}
98-
auto env_or = litert::Environment::Create(options);
95+
auto env_or = Environment::Create(options);
9996
if (!env_or) {
10097
if (out_error) *out_error = env_or.Error().Message();
10198
return nullptr;
10299
}
103-
litert::Environment env = std::move(*env_or);
100+
Environment env = std::move(*env_or);
104101

105-
// Load model from file
106-
auto model_or = litert::Model::CreateFromFile(model_path);
102+
// Load model from a file
103+
auto model_or = Model::CreateFromFile(model_path);
107104
if (!model_or) {
108105
if (out_error) *out_error = model_or.Error().Message();
109106
return nullptr;
110107
}
111-
litert::Model model = std::move(*model_or);
108+
Model model = std::move(*model_or);
112109

113-
// Create compiled model
114-
auto compiled_or = litert::CompiledModel::Create(
115-
env, model, (LiteRtHwAccelerators)hardware_accel);
110+
// Create a compiled model
111+
auto compiled_or = CompiledModel::Create(
112+
env, model, static_cast<LiteRtHwAccelerators>(hardware_accel));
116113
if (!compiled_or) {
117114
if (out_error) *out_error = compiled_or.Error().Message();
118115
return nullptr;
@@ -145,37 +142,37 @@ CompiledModelWrapper* CompiledModelWrapper::CreateWrapperFromBuffer(
145142
}
146143

147144
// Create environment with options
148-
std::vector<litert::Environment::Option> options;
145+
std::vector<Environment::Option> options;
149146
if (compiler_plugin_path && *compiler_plugin_path) {
150-
options.push_back(litert::Environment::Option{
151-
litert::Environment::OptionTag::CompilerPluginLibraryDir,
152-
std::string(compiler_plugin_path)});
147+
options.push_back(
148+
Environment::Option{Environment::OptionTag::CompilerPluginLibraryDir,
149+
std::string(compiler_plugin_path)});
153150
}
154151
if (dispatch_library_path && *dispatch_library_path) {
155-
options.push_back(litert::Environment::Option{
156-
litert::Environment::OptionTag::DispatchLibraryDir,
157-
std::string(dispatch_library_path)});
152+
options.push_back(
153+
Environment::Option{Environment::OptionTag::DispatchLibraryDir,
154+
std::string(dispatch_library_path)});
158155
}
159156

160-
auto env_or = litert::Environment::Create(options);
157+
auto env_or = Environment::Create(options);
161158
if (!env_or) {
162159
if (out_error) *out_error = env_or.Error().Message();
163160
return nullptr;
164161
}
165-
litert::Environment env = std::move(*env_or);
162+
Environment env = std::move(*env_or);
166163

167164
// Create model from buffer
168-
litert::BufferRef<uint8_t> ref(reinterpret_cast<uint8_t*>(buf),
169-
static_cast<size_t>(length));
170-
auto model_or = litert::Model::CreateFromBuffer(ref);
165+
BufferRef<uint8_t> ref(reinterpret_cast<uint8_t*>(buf),
166+
static_cast<size_t>(length));
167+
auto model_or = Model::CreateFromBuffer(ref);
171168
if (!model_or) {
172169
if (out_error) *out_error = model_or.Error().Message();
173170
return nullptr;
174171
}
175-
litert::Model model = std::move(*model_or);
172+
Model model = std::move(*model_or);
176173

177-
// Create compiled model
178-
auto compiled_or = litert::CompiledModel::Create(
174+
// Create a compiled model
175+
auto compiled_or = CompiledModel::Create(
179176
env, model, static_cast<LiteRtHwAccelerators>(hardware_accel));
180177
if (!compiled_or) {
181178
if (out_error) *out_error = compiled_or.Error().Message();
@@ -266,7 +263,7 @@ PyObject* CompiledModelWrapper::GetSignatureByIndex(int signature_index) {
266263
// Returns the number of signatures in the model.
267264
PyObject* CompiledModelWrapper::GetNumSignatures() {
268265
auto num = model_.GetNumSignatures();
269-
return PyLong_FromLong((int64_t)num);
266+
return PyLong_FromLong(static_cast<int64_t>(num));
270267
}
271268

272269
// Returns the index of a signature by key.
@@ -296,7 +293,8 @@ PyObject* CompiledModelWrapper::GetInputBufferRequirements(int signature_index,
296293
Py_DECREF(dict);
297294
return ConvertErrorToPyExc(size_or.Error());
298295
}
299-
PyDict_SetItemString(dict, "buffer_size", PyLong_FromLong((int64_t)*size_or));
296+
PyDict_SetItemString(dict, "buffer_size",
297+
PyLong_FromLong(static_cast<int64_t>(*size_or)));
300298

301299
// Add supported types
302300
auto types_or = req.SupportedTypes();
@@ -305,7 +303,7 @@ PyObject* CompiledModelWrapper::GetInputBufferRequirements(int signature_index,
305303
return ConvertErrorToPyExc(types_or.Error());
306304
}
307305
auto types = std::move(*types_or);
308-
PyObject* py_list = PyList_New((Py_ssize_t)types.size());
306+
PyObject* py_list = PyList_New(static_cast<Py_ssize_t>(types.size()));
309307
for (size_t i = 0; i < types.size(); i++) {
310308
PyList_SetItem(py_list, i, PyLong_FromLong(types[i]));
311309
}
@@ -319,7 +317,7 @@ PyObject* CompiledModelWrapper::GetInputBufferRequirements(int signature_index,
319317
PyObject* CompiledModelWrapper::GetOutputBufferRequirements(int signature_index,
320318
int output_index) {
321319
auto req_or = compiled_model_.GetOutputBufferRequirements(
322-
(size_t)signature_index, (size_t)output_index);
320+
static_cast<size_t>(signature_index), static_cast<size_t>(output_index));
323321
if (!req_or) {
324322
return ConvertErrorToPyExc(req_or.Error());
325323
}
@@ -333,15 +331,16 @@ PyObject* CompiledModelWrapper::GetOutputBufferRequirements(int signature_index,
333331
Py_DECREF(dict);
334332
return ConvertErrorToPyExc(size_or.Error());
335333
}
336-
PyDict_SetItemString(dict, "buffer_size", PyLong_FromLong((int64_t)*size_or));
334+
PyDict_SetItemString(dict, "buffer_size",
335+
PyLong_FromLong(static_cast<int64_t>(*size_or)));
337336

338337
auto types_or = req.SupportedTypes();
339338
if (!types_or) {
340339
Py_DECREF(dict);
341340
return ConvertErrorToPyExc(types_or.Error());
342341
}
343342
auto types = std::move(*types_or);
344-
PyObject* py_list = PyList_New((Py_ssize_t)types.size());
343+
PyObject* py_list = PyList_New(static_cast<Py_ssize_t>(types.size()));
345344
for (size_t i = 0; i < types.size(); i++) {
346345
PyList_SetItem(py_list, i, PyLong_FromLong(types[i]));
347346
}
@@ -377,7 +376,8 @@ PyObject* CompiledModelWrapper::CreateOutputBufferByName(
377376
}
378377

379378
PyObject* CompiledModelWrapper::CreateInputBuffers(int signature_index) {
380-
auto buffers_or = compiled_model_.CreateInputBuffers((size_t)signature_index);
379+
auto buffers_or =
380+
compiled_model_.CreateInputBuffers(static_cast<size_t>(signature_index));
381381
if (!buffers_or) {
382382
return ConvertErrorToPyExc(buffers_or.Error());
383383
}
@@ -387,14 +387,14 @@ PyObject* CompiledModelWrapper::CreateInputBuffers(int signature_index) {
387387
// Python owns them. Destroy on capsule destructor.
388388
PyObject* capsule =
389389
litert_wrapper_utils::MakeTensorBufferCapsule(buffers[i]);
390-
PyList_SetItem(py_list, i, capsule); // steals ref
390+
PyList_SetItem(py_list, i, capsule); // steal ref
391391
}
392392
return py_list;
393393
}
394394

395395
PyObject* CompiledModelWrapper::CreateOutputBuffers(int signature_index) {
396396
auto buffers_or =
397-
compiled_model_.CreateOutputBuffers((size_t)signature_index);
397+
compiled_model_.CreateOutputBuffers(static_cast<size_t>(signature_index));
398398
if (!buffers_or) {
399399
return ConvertErrorToPyExc(buffers_or.Error());
400400
}
@@ -415,8 +415,8 @@ PyObject* CompiledModelWrapper::RunByName(const char* signature_key,
415415
return ReportError("RunByName expects input_map & output_map as dict");
416416
}
417417

418-
absl::flat_hash_map<absl::string_view, litert::TensorBuffer> in_map;
419-
absl::flat_hash_map<absl::string_view, litert::TensorBuffer> out_map;
418+
absl::flat_hash_map<absl::string_view, TensorBuffer> in_map;
419+
absl::flat_hash_map<absl::string_view, TensorBuffer> out_map;
420420

421421
PyObject* key;
422422
PyObject* val;
@@ -436,7 +436,7 @@ PyObject* CompiledModelWrapper::RunByName(const char* signature_key,
436436
return ReportError("capsule missing pointer in input_map");
437437
}
438438
in_map[nm] =
439-
litert::TensorBuffer((LiteRtTensorBuffer)ptr, litert::OwnHandle::kNo);
439+
TensorBuffer(static_cast<LiteRtTensorBuffer>(ptr), OwnHandle::kNo);
440440
}
441441

442442
pos = 0;
@@ -455,11 +455,11 @@ PyObject* CompiledModelWrapper::RunByName(const char* signature_key,
455455
return ReportError("capsule missing pointer in output_map");
456456
}
457457
out_map[nm] =
458-
litert::TensorBuffer((LiteRtTensorBuffer)ptr, litert::OwnHandle::kNo);
458+
TensorBuffer(static_cast<LiteRtTensorBuffer>(ptr), OwnHandle::kNo);
459459
}
460460

461-
auto run_or = compiled_model_.Run(signature_key, in_map, out_map);
462-
if (!run_or) {
461+
if (auto run_or = compiled_model_.Run(signature_key, in_map, out_map);
462+
!run_or) {
463463
return ConvertErrorToPyExc(run_or.Error());
464464
}
465465
Py_RETURN_NONE;
@@ -474,8 +474,8 @@ PyObject* CompiledModelWrapper::RunByIndex(int signature_index,
474474
if (!PyList_Check(output_caps_list)) {
475475
return ReportError("RunByIndex output_caps_list not list");
476476
}
477-
std::vector<litert::TensorBuffer> inputs;
478-
std::vector<litert::TensorBuffer> outputs;
477+
std::vector<TensorBuffer> inputs;
478+
std::vector<TensorBuffer> outputs;
479479

480480
Py_ssize_t n_in = PyList_Size(input_caps_list);
481481
inputs.reserve(n_in);
@@ -489,7 +489,7 @@ PyObject* CompiledModelWrapper::RunByIndex(int signature_index,
489489
if (!ptr) {
490490
return ReportError("Missing pointer in input capsule");
491491
}
492-
inputs.emplace_back((LiteRtTensorBuffer)ptr, litert::OwnHandle::kNo);
492+
inputs.emplace_back(static_cast<LiteRtTensorBuffer>(ptr), OwnHandle::kNo);
493493
}
494494

495495
Py_ssize_t n_out = PyList_Size(output_caps_list);
@@ -504,15 +504,15 @@ PyObject* CompiledModelWrapper::RunByIndex(int signature_index,
504504
if (!ptr) {
505505
return ReportError("Missing pointer in output capsule");
506506
}
507-
outputs.emplace_back((LiteRtTensorBuffer)ptr, litert::OwnHandle::kNo);
507+
outputs.emplace_back(static_cast<LiteRtTensorBuffer>(ptr), OwnHandle::kNo);
508508
}
509509

510-
auto run_or = compiled_model_.Run((size_t)signature_index, inputs, outputs);
511-
if (!run_or) {
510+
if (auto run_or = compiled_model_.Run(static_cast<size_t>(signature_index),
511+
inputs, outputs);
512+
!run_or) {
512513
return ConvertErrorToPyExc(run_or.Error());
513514
}
514515
Py_RETURN_NONE;
515516
}
516517

517-
} // namespace compiled_model_wrapper
518-
} // namespace litert
518+
} // namespace litert::compiled_model_wrapper

litert/python/litert_wrapper/compiled_model_wrapper/compiled_model_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class CompiledModelWrapper {
136136

137137
private:
138138
// Returns the size in bytes of a single element of the given data type.
139-
size_t ByteWidthOfDType(const std::string& dtype);
139+
static size_t ByteWidthOfDType(const std::string& dtype);
140140

141141
// Reports an error to Python and returns nullptr.
142142
static PyObject* ReportError(const std::string& msg);

0 commit comments

Comments
 (0)