1+ #ifndef __CONV_MOORE_H__
2+ #define __CONV_MOORE_H__
3+
4+ #include " conv_mudnn.h"
5+
6+ namespace op ::conv::moore {
7+
8+ // Descriptor class for CONV operations on Moore devices.
9+ // This class acts as a wrapper to select mudnn backend.
10+ // It encapsulates the backend-specific Descriptor implementation and provides
11+ // a unified interface for workspace query and CONV calculation.
12+ class Descriptor final : public InfiniopDescriptor {
13+ public:
14+ // Destructor: deletes the backend-specific descriptor.
15+ ~Descriptor () {
16+ delete reinterpret_cast <mudnn::Descriptor *>(_impl);
17+ }
18+
19+ // Returns the required workspace size for the CONV operation.
20+ size_t workspaceSize () const {
21+ return reinterpret_cast <mudnn::Descriptor *>(_impl)->workspaceSize ();
22+ }
23+
24+ // Static factory method to create a Descriptor instance.
25+ // This method chooses the backend (mudnn) and constructs
26+ // the corresponding implementation internally.
27+ static infiniStatus_t create (
28+ infiniopHandle_t handle,
29+ Descriptor **desc_ptr,
30+ infiniopTensorDescriptor_t y_desc,
31+ infiniopTensorDescriptor_t x_desc,
32+ infiniopTensorDescriptor_t w_desc,
33+ infiniopTensorDescriptor_t b_desc,
34+ const void *pads,
35+ const void *strides,
36+ const void *dilations,
37+ size_t n) {
38+ auto desc = new Descriptor (handle->device , handle->device_id );
39+
40+ // Backend selection strategy:
41+ // Currently defaulting to MUDNN.
42+ // Can be modified to choose based on environment variables or runtime parameters.
43+ desc->_backend = Backend::MUDNN;
44+
45+ mudnn::Descriptor *impl;
46+ auto status = mudnn::Descriptor::create (handle, &impl, y_desc, x_desc, w_desc, b_desc, pads, strides, dilations, n);
47+ if (status != INFINI_STATUS_SUCCESS) {
48+ delete desc;
49+ return status;
50+ }
51+ desc->_impl = impl;
52+
53+ *desc_ptr = desc;
54+ return INFINI_STATUS_SUCCESS;
55+ }
56+
57+ // Unified CONV calculation interface.
58+ // Calls the corresponding backend's calculate function internally.
59+ infiniStatus_t calculate (
60+ void *workspace, size_t workspace_size,
61+ void *y,
62+ const void *x,
63+ const void *w,
64+ const void *bias,
65+ void *stream) const {
66+ return reinterpret_cast <mudnn::Descriptor *>(_impl)
67+ ->calculate (workspace, workspace_size, y, x, w, bias, stream);
68+ }
69+
70+ private:
71+ // Private constructor: ensures users cannot directly instantiate Descriptor.
72+ // Instances must be created via the static create() factory method.
73+ Descriptor (infiniDevice_t device_type, int device_id)
74+ : InfiniopDescriptor{device_type, device_id}, _impl(nullptr ) {}
75+
76+ // Enum to indicate which backend is being used internally.
77+ enum class Backend { MUDNN };
78+
79+ Backend _backend; // Currently selected MUDNN backend
80+ void *_impl; // Pointer to backend-specific descriptor (mudnn::Descriptor*)
81+ };
82+
83+ } // namespace op::conv::moore
84+
85+ #endif // __CONV_MOORE_H__
0 commit comments