@@ -96,39 +96,43 @@ struct TensorMeta {
9696
9797/* *
9898 * Describes which dtype & dim order specialized kernel to be bound to an
99- * operator. If `is_fallback_` is true, it means this kernel can be used as a
100- * fallback, if false, it means this kernel can only be used if all the
101- * `TensorMeta` are matched. Fallback means this kernel will be used for
102- * all input tensor dtypes and dim orders, if the specialized kernel is not
103- * registered.
99+ * operator.
104100 *
105- * The format of a kernel key data is a string:
106- * "v<version>/<tensor_meta>|<tensor_meta>..."
107- * Size: Up to 691 1 1 1 (42 +1) * 16
108- * Assuming max number of tensors is 16 ^
109- * Kernel key version is v1 for now. If the kernel key format changes,
110- * update the version to avoid breaking pre-existing kernel keys.
111- * Example: v1/7;0,1,2,3
112- * The kernel key has only one tensor: a double tensor with dimension 0, 1, 2, 3
101+ * Kernel key data is a string with the format:
102+ *
103+ * "v<version>/<tensor_meta>|<tensor_meta>..."
104+ *
105+ * The version is v1 for now. If the kernel key format changes, update the
106+ * version to avoid breaking pre-existing kernel keys.
113107 *
114108 * Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
115- * Size: Up to 42 1-2 1 24 (1 byte for 0-9; 2
116- * for 10-15) + 15 commas Assuming that the max number of dims is 16 ^ Example:
117- * 7;0,1,2,3 for [double; 0, 1, 2, 3]
109+ *
110+ * Example kernel key data: "v1/7;0,1,2,3|1;0,1,2,3,4,5,6,7"
111+ *
112+ * This has two tensors: the first with dtype=7 and dim order 0,1,2,3, and the
113+ * second with dtype=1 and dim order 0,1,2,3,4,5,6,7.
118114 *
119115 * IMPORTANT:
120116 * Users should not construct a kernel key manually. Instead, it should be
121117 * generated from kernel yaml.
122118 */
123119struct KernelKey {
124120 public:
121+ /* *
122+ * Creates a fallback (non-specialized) kernel key: this kernel can be used
123+ * for all input tensor dtypes and dim orders if the specialized kernel is not
124+ * registered.
125+ */
125126 KernelKey () : is_fallback_(true ) {}
126127
128+ /* *
129+ * Creates a specialized (non-fallback) kernel key that matches a specific
130+ * set of input tensor dtypes and dim orders. See the class comment for the
131+ * expected format of `kernel_key_data`.
132+ */
127133 /* implicit */ KernelKey(const char * kernel_key_data)
128134 : kernel_key_data_(kernel_key_data), is_fallback_(false ) {}
129135
130- constexpr static int MAX_SIZE = 691 ;
131-
132136 bool operator ==(const KernelKey& other) const {
133137 return this ->equals (other);
134138 }
@@ -144,7 +148,7 @@ struct KernelKey {
144148 if (is_fallback_) {
145149 return true ;
146150 }
147- return strncmp (kernel_key_data_, other.kernel_key_data_ , MAX_SIZE ) == 0 ;
151+ return strcmp (kernel_key_data_, other.kernel_key_data_ ) == 0 ;
148152 }
149153
150154 bool is_fallback () const {
@@ -194,7 +198,23 @@ struct Kernel {
194198};
195199
196200namespace internal {
197- void make_kernel_key_string (Span<const TensorMeta> key, char * buf);
201+
202+ /* *
203+ * A make_kernel_key_string buffer size that is large enough to hold a kernel
204+ * key string with 16 tensors of 16 dimensions, plus the trailing NUL byte.
205+ */
206+ constexpr size_t kKernelKeyBufSize = 659 ;
207+
208+ /* *
209+ * Given the list of input tensor dtypes + dim orders, writes the kernel key
210+ * string into the buffer. Returns an error if the buffer is too small or if the
211+ * tensors cannot be represented as a valid key string.
212+ */
213+ Error make_kernel_key_string (
214+ Span<const TensorMeta> key,
215+ char * buf,
216+ size_t buf_size);
217+
198218} // namespace internal
199219
200220/* *
0 commit comments